Example #1
# Assumed imports: in the surrounding project these helpers live in a
# copy-attention utilities module (the exact path is an assumption here).
# from c2nl.utils.copy_utils import make_src_map, align, collapse_copy_scores

def prepare_batch(batch, model):
    # To enable copy attn, collect source map and alignment info
    batch_inputs = dict()

    if model.args.copy_attn:
        assert 'src_map' in batch and 'alignment' in batch
        source_map = make_src_map(batch['src_map'])
        source_map = source_map.cuda(non_blocking=True) if model.args.cuda \
            else source_map
        if batch['alignment'][0] is not None:
            alignment = align(batch['alignment'])
            alignment = alignment.cuda(non_blocking=True) if model.args.cuda \
                else alignment
        else:
            alignment = None
        blank, fill = collapse_copy_scores(model.tgt_dict, batch['src_vocab'])
    else:
        source_map, alignment = None, None
        blank, fill = None, None

    batch_inputs['src_map'] = source_map
    batch_inputs['alignment'] = alignment
    batch_inputs['blank'] = blank
    batch_inputs['fill'] = fill

    code_word_rep = batch['code_word_rep']
    code_char_rep = batch['code_char_rep']
    code_type_rep = batch['code_type_rep']
    code_mask_rep = batch['code_mask_rep']
    code_len = batch['code_len']
    if model.args.cuda:
        code_len = code_len.cuda(non_blocking=True)
        if code_word_rep is not None:
            code_word_rep = code_word_rep.cuda(non_blocking=True)
        if code_char_rep is not None:
            code_char_rep = code_char_rep.cuda(non_blocking=True)
        if code_type_rep is not None:
            code_type_rep = code_type_rep.cuda(non_blocking=True)
        if code_mask_rep is not None:
            code_mask_rep = code_mask_rep.cuda(non_blocking=True)

    batch_inputs['code_word_rep'] = code_word_rep
    batch_inputs['code_char_rep'] = code_char_rep
    batch_inputs['code_type_rep'] = code_type_rep
    batch_inputs['code_mask_rep'] = code_mask_rep
    batch_inputs['code_len'] = code_len
    return batch_inputs
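
Since prepare_batch only packs tensors into a dict (and moves them to the GPU
when requested), it can be exercised in isolation. A minimal sketch, using
SimpleNamespace stand-ins for the real model and batch objects; the tensor
shapes below are illustrative assumptions, not the project's actual
representations:

import torch
from types import SimpleNamespace

# Stand-in "model": only .args.copy_attn and .args.cuda are read above.
model = SimpleNamespace(args=SimpleNamespace(copy_attn=False, cuda=False))

# Stand-in batch: word ids of shape (batch, src_len); optional reps left None.
batch = {
    'code_word_rep': torch.randint(0, 100, (2, 7)),
    'code_char_rep': None,
    'code_type_rep': None,
    'code_mask_rep': None,
    'code_len': torch.tensor([7, 5]),
}

inputs = prepare_batch(batch, model)
# Without copy attention, every copy-related entry is None.
assert inputs['src_map'] is None and inputs['blank'] is None
print(sorted(inputs.keys()))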
Example #2
    def predict(self, ex, replace_unk=False):
        """Forward a batch of examples only to get predictions.
        Args:
            ex: the batch of examples
            replace_unk: replace `unk` tokens while generating predictions;
                the raw source tokens needed for replacement are taken from
                ex['code_tokens']
        Output:
            predictions: #batch predicted sequences
        """
        # Eval mode
        self.network.eval()

        source_map, alignment = None, None
        blank, fill = None, None
        # To enable copy attn, collect source map and alignment info
        if self.args.copy_attn:
            assert 'src_map' in ex and 'alignment' in ex

            source_map = make_src_map(ex['src_map'])
            source_map = source_map.cuda(non_blocking=True) if self.use_cuda \
                else source_map

            if ex['alignment'][0][0] is not None:
                alignment = align(ex['alignment'])
                alignment = alignment.cuda(non_blocking=True) if self.use_cuda \
                    else alignment

            blank, fill = collapse_copy_scores(self.tgt_dict, ex['src_vocab'])

        code_word_rep = ex['code_word_rep']
        code_char_rep = ex['code_char_rep']
        code_type_rep = ex['code_type_rep']
        code_mask_rep = ex['code_mask_rep']
        code_len = ex['code_len']
        if self.use_cuda:
            code_len = code_len.cuda(non_blocking=True)
            if code_word_rep is not None:
                code_word_rep = code_word_rep.cuda(non_blocking=True)
            if code_char_rep is not None:
                code_char_rep = code_char_rep.cuda(non_blocking=True)
            if code_type_rep is not None:
                code_type_rep = code_type_rep.cuda(non_blocking=True)
            if code_mask_rep is not None:
                code_mask_rep = code_mask_rep.cuda(non_blocking=True)

        decoder_out = self.network(code_word_rep=code_word_rep,
                                   code_char_rep=code_char_rep,
                                   code_type_rep=code_type_rep,
                                   code_len=code_len,
                                   summ_word_rep=None,
                                   summ_char_rep=None,
                                   summ_len=None,
                                   tgt_seq=None,
                                   src_map=source_map,
                                   alignment=alignment,
                                   max_len=self.args.max_tgt_len,
                                   src_dict=self.src_dict,
                                   tgt_dict=self.tgt_dict,
                                   blank=blank,
                                   fill=fill,
                                   source_vocab=ex['src_vocab'],
                                   code_mask_rep=code_mask_rep)

        predictions = tens2sen(decoder_out['predictions'], self.tgt_dict,
                               ex['src_vocab'])
        if replace_unk:
            for i in range(len(predictions)):
                enc_dec_attn = decoder_out['attentions'][i]
                if self.args.model_type == 'transformer':
                    assert enc_dec_attn.dim() == 3
                    enc_dec_attn = enc_dec_attn.mean(1)
                predictions[i] = replace_unknown(predictions[i],
                                                 enc_dec_attn,
                                                 src_raw=ex['code_tokens'][i])
                if self.args.uncase:
                    predictions[i] = predictions[i].lower()

        targets = list(ex['summ_text'])
        return predictions, targets, decoder_out['copy_info']
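
The replace_unknown call above substitutes each predicted unk token with the
source token that received the most encoder-decoder attention at that decoding
step. A minimal sketch of that idea; the helper name, unk string, and tensor
layout here are assumptions, not the project's actual implementation:

import torch

def replace_unknown_sketch(prediction, attn, src_raw, unk='<unk>'):
    # prediction: space-joined predicted tokens for one example
    # attn: (tgt_len, src_len) attention weights for the same example
    # src_raw: list of raw source tokens to copy from
    tokens = prediction.split()
    for i, tok in enumerate(tokens):
        if tok == unk and i < attn.size(0):
            j = int(attn[i].argmax())  # most-attended source position
            if j < len(src_raw):
                tokens[i] = src_raw[j]
    return ' '.join(tokens)

attn = torch.tensor([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]])
print(replace_unknown_sketch('<unk> value', attn, ['set', 'max', 'len']))
# -> 'max value'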
Example #3
    def update(self, ex):
        """Forward a batch of examples; step the optimizer to update weights."""
        if not self.optimizer:
            raise RuntimeError('No optimizer set.')

        # Train mode
        self.network.train()

        source_map, alignment = None, None
        blank, fill = None, None
        # To enable copy attn, collect source map and alignment info
        if self.args.copy_attn:
            assert 'src_map' in ex and 'alignment' in ex

            source_map = make_src_map(ex['src_map'])
            source_map = source_map.cuda(non_blocking=True) if self.use_cuda \
                else source_map

            alignment = align(ex['alignment'])
            alignment = alignment.cuda(non_blocking=True) if self.use_cuda \
                else alignment

            blank, fill = collapse_copy_scores(self.tgt_dict, ex['src_vocab'])

        code_word_rep = ex['code_word_rep']
        code_char_rep = ex['code_char_rep']
        code_type_rep = ex['code_type_rep']
        code_mask_rep = ex['code_mask_rep']
        code_len = ex['code_len']
        summ_word_rep = ex['summ_word_rep']
        summ_char_rep = ex['summ_char_rep']
        summ_len = ex['summ_len']
        tgt_seq = ex['tgt_seq']

        if any(l is None for l in ex['language']):
            ex_weights = None
        else:
            ex_weights = [
                self.args.dataset_weights[lang] for lang in ex['language']
            ]
            ex_weights = torch.FloatTensor(ex_weights)

        if self.use_cuda:
            code_len = code_len.cuda(non_blocking=True)
            summ_len = summ_len.cuda(non_blocking=True)
            tgt_seq = tgt_seq.cuda(non_blocking=True)
            if code_word_rep is not None:
                code_word_rep = code_word_rep.cuda(non_blocking=True)
            if code_char_rep is not None:
                code_char_rep = code_char_rep.cuda(non_blocking=True)
            if code_type_rep is not None:
                code_type_rep = code_type_rep.cuda(non_blocking=True)
            if code_mask_rep is not None:
                code_mask_rep = code_mask_rep.cuda(non_blocking=True)
            if summ_word_rep is not None:
                summ_word_rep = summ_word_rep.cuda(non_blocking=True)
            if summ_char_rep is not None:
                summ_char_rep = summ_char_rep.cuda(non_blocking=True)
            if ex_weights is not None:
                ex_weights = ex_weights.cuda(non_blocking=True)

        # Run forward
        net_loss = self.network(code_word_rep=code_word_rep,
                                code_char_rep=code_char_rep,
                                code_type_rep=code_type_rep,
                                code_len=code_len,
                                summ_word_rep=summ_word_rep,
                                summ_char_rep=summ_char_rep,
                                summ_len=summ_len,
                                tgt_seq=tgt_seq,
                                src_map=source_map,
                                alignment=alignment,
                                src_dict=self.src_dict,
                                tgt_dict=self.tgt_dict,
                                max_len=self.args.max_tgt_len,
                                blank=blank,
                                fill=fill,
                                source_vocab=ex['src_vocab'],
                                code_mask_rep=code_mask_rep,
                                example_weights=ex_weights)

        loss = net_loss['ml_loss'].mean() if self.parallel \
            else net_loss['ml_loss']
        loss_per_token = net_loss['loss_per_token'].mean() if self.parallel \
            else net_loss['loss_per_token']
        ml_loss = loss.item()
        loss_per_token = loss_per_token.item()
        # Cap the reported per-token loss so the perplexity stays bounded
        loss_per_token = min(loss_per_token, 10)
        perplexity = math.exp(loss_per_token)

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        loss.backward()

        clip_grad_norm_(self.network.parameters(), self.args.grad_clipping)
        self.optimizer.step()
        self.optimizer.zero_grad()

        self.updates += 1
        return {'ml_loss': ml_loss, 'perplexity': perplexity}
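
update reports perplexity as the exponential of the per-token loss, with the
loss capped at 10 so the reported value stays bounded. A quick worked check of
that relationship:

import math

for loss_per_token in (0.5, 2.3, 42.0):
    capped = min(loss_per_token, 10)
    print(loss_per_token, '->', round(math.exp(capped), 2))
# 0.5 -> 1.65, 2.3 -> 9.97, 42.0 is capped to 10 -> 22026.47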