Example 1
    def forward(self, inputs):
        words, masks, pos, deprel, head, subj_pos, obj_pos = inputs  # unpack
        src_mask = (words != constant.PAD_ID).unsqueeze(-2)

        word_embs = self.emb(words)
        embs = [word_embs]

        if self.opt['pos_dim'] > 0:
            embs += [self.pos_emb(pos)]
        embs = torch.cat(embs, dim=2)
        embs = self.in_drop(embs)

        if self.opt.get('rnn', False):
            embs = self.input_W_R(embs)
            gcn_inputs = self.rnn_drop(
                self.encode_with_rnn(embs, masks,
                                     words.size()[0]))
        else:
            gcn_inputs = embs

        gcn_inputs = self.input_W_G(gcn_inputs)

        layer_list = []
        outputs = gcn_inputs

        adj_list = None

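        # layers at index 0 and 3 are attention layers: given the node states and
        # the source mask they return a list of adjacency matrices (one per head),
        # sparsified by entmax_bisect with a learnable per-head alpha; the remaining
        # layers are GCN blocks that propagate node states over those adjacencies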
        for i in range(len(self.layers)):
            if i == 0 or i == 3:
                adj_list = self.layers[i](outputs, src_mask)
                if self.opt['data_dir'] != 'dataset/semeval':
                    for j in range(len(adj_list)):
                        if i == 3:
                            adj_list[j] = entmax_bisect(
                                adj_list[j], self.alpha_list[self.heads + j])
                        else:
                            adj_list[j] = entmax_bisect(
                                adj_list[j], self.alpha_list[j])
            else:
                outputs = self.layers[i](adj_list, outputs)
                layer_list.append(outputs)

        aggregate_out = torch.cat(layer_list, dim=2)
        dcgcn_output = self.aggregate_W(aggregate_out)
        adj = torch.stack(adj_list, dim=1).sum(dim=1)

        mask = (adj.sum(2) + adj.sum(1)).eq(0).unsqueeze(2)

        return dcgcn_output, mask
Example 2
def entmax(input_ids, tokenizer, model, prompt, epoch=None, alpha=1.5, max_length=50):

    new_input_ids = deepcopy(input_ids)
    alpha = torch.tensor(alpha, requires_grad=True)

    log = []

    # print(input_ids)

    for _ in range(max_length):

        prediction_scores = model(new_input_ids)[0][0][-1]
        prediction_prob = entmax_bisect(prediction_scores, alpha)
        candidates = torch.nonzero(prediction_prob)
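        # entmax can zero out most of the vocabulary; pick uniformly at random
        # among the tokens that remain in the support (nonzero probability)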
        next_token_id = candidates[torch.randint(candidates.size()[0], (1,))]

        # print(tokenizer.decode(new_input_ids[0], skip_special_tokens=False))
        # print(tokenizer.decode(next_token_id[0], skip_special_tokens=False),':\t', prediction_prob[next_token_id].data[0][0])

        new_input_ids = torch.cat((new_input_ids, next_token_id), dim=1)

        log.append((tokenizer.decode(next_token_id[0], skip_special_tokens=False), prediction_prob[next_token_id].item()))


    # pprint(log)
    output_sent = tokenizer.decode(new_input_ids[0], skip_special_tokens=False)
    # if epoch is not None:
    #     prompt = f'epoch{epoch}_{prompt}'
    draw_prob_graph(log, text=output_sent, filename=prompt, title=f'GPT entmax epoch{epoch}')

    return output_sent
Example 3
 def attention(self, query, key, value, mask=None, dropout=None):
     "Compute 'Scaled Dot Product Attention'"
     d_k = query.size(-1)
     scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
     if mask is not None:
         scores = scores.masked_fill(mask == 0, -1e9)
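     # entmax with a tunable alpha replaces the usual softmax; it can assign
     # exactly zero weight to low-scoring positions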
     p_attn = entmax_bisect(scores, alpha=self.alpha, dim=-1)
     if dropout is not None:
         p_attn = dropout(p_attn)
     return torch.matmul(p_attn, value), p_attn
Example 4
 def __init__(
     self,
     plate_name: str,
     sampling_method: Optional[storch.sampling.SamplingMethod] = None,
     alpha: float = 1.5,
     adaptive=False,
     n_samples: int = 1,
     straight_through=False,
     initial_temperature=1.0,
     min_temperature=1.0e-4,
     annealing_rate=0.0,
 ):
     if not sampling_method:
         sampling_method = storch.sampling.MonteCarlo(plate_name, n_samples)
     super().__init__(
         plate_name, sampling_method.set_mc_sample(self.sample_gumbel_entmax),
     )
     self.adaptive = adaptive
     self.straight_through = straight_through
     self.register_buffer("temperature", torch.tensor(initial_temperature))
     self.register_buffer("annealing_rate", torch.tensor(annealing_rate))
     self.register_buffer("min_temperature", torch.tensor(min_temperature))
     self.alpha = alpha
     if adaptive:
         self.alpha = torch.nn.Parameter(
             torch.tensor(self.alpha, requires_grad=True)
         )
     if not adaptive and alpha == 1.5:
         self.entmax = entmax.entmax15
     elif not adaptive and alpha == 2.0:
         self.entmax = entmax.sparsemax
     else:
         if adaptive:
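             # softplus(alpha - 1) + 1 keeps the learned alpha strictly above 1,
             # the range where entmax is defined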
             self.entmax = lambda x: entmax.entmax_bisect(
                 x, torch.nn.functional.softplus(self.alpha - 1) + 1
             )
         else:
             self.entmax = lambda x: entmax.entmax_bisect(x, self.alpha)
Example 5
    def forward(self, x):
        b, t, e = x.size()
        h = self.heads
        assert e == self.emb, f'Input embedding dim ({e}) should match layer embedding dim ({self.emb})'

        s = e // h
        x = x.view(b, t, h, s)

        keys = self.tokeys(x)
        queries = self.toqueries(x)
        values = self.tovalues(x)

        assert keys.size() == (b, t, h, s)
        assert queries.size() == (b, t, h, s)
        assert values.size() == (b, t, h, s)

        # Compute scaled dot-product self-attention

        # - fold heads into the batch dimension
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

        queries = queries / (e**(1 / 4))
        keys = keys / (e**(1 / 4))
        # - Instead of dividing the dot products by sqrt(e), we scale the queries and keys.
        #   This should be more memory efficient

        # - get dot product of queries and keys, and scale
        dot = torch.bmm(queries, keys.transpose(1, 2))

        assert dot.size() == (b * h, t, t)

        if self.mask:  # mask out the upper half of the dot matrix, excluding the diagonal
            mask_(dot, maskval=float('-inf'), mask_diagonal=False)

        # dot = F.softmax(dot, dim=-1)
        # dot = sparsemax(dot, dim=-1)
        dot = entmax_bisect(dot, alpha=self.alpha, dim=-1)

        # - dot now has row-wise self-attention probabilities

        # apply the self attention to the values
        out = torch.bmm(dot, values).view(b, h, t, s)

        # swap h, t back, unify heads
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)

        return self.unifyheads(out)
Example 6
    def forward(self, scores: torch.Tensor,
                mask: torch.BoolTensor) -> torch.Tensor:
        """Map a score vector to a probability distribution akin to softmax (alpha=1) and sparsemax (alpha=2)

        Args:
            scores (torch.Tensor): (Batch x Sequence Length)
                Attention scores (also referred to as weights)
            mask (torch.BoolTensor): (Batch x Sequence Length)
                Specifies which indices are just padding

        Returns:
            torch.Tensor: Distribution resulting from entmax with specified alpha
        """
        # Entmax is only defined for alpha > 1
        self.alpha.data = torch.clamp(self.alpha.data, min=1.001)

        masked_scores = replace_masked_values(scores, mask, -float("inf"))
        return entmax_bisect(masked_scores, self.alpha, dim=-1)
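As the docstring above notes, alpha interpolates between softmax (alpha near 1) and sparsemax (alpha = 2). A minimal standalone sketch, assuming only that the entmax package is installed (the toy scores are made up), showing the distribution getting sparser as alpha grows:

import torch
from entmax import entmax_bisect

scores = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
for alpha in (1.001, 1.5, 2.0):
    # alpha near 1 behaves like softmax (dense); alpha = 2 is sparsemax,
    # which can assign exact zeros; 1.5 sits in between
    print(alpha, entmax_bisect(scores, alpha=alpha, dim=-1))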
Example 7
def alpha_entmax_loss(model, batch, args):
    longer_sample = batch[0].to(args.gpu)
    inp = longer_sample[:, :args.train_batch_size]
    model_output = model(input_ids=inp)
    target = longer_sample[:, 1:args.train_batch_size + 1]
    logits = model_output[0]
    alpha = torch.tensor([args.alpha],
                         requires_grad=True,
                         device=torch.device(args.gpu))
    probs = entmax_bisect(logits, alpha)
    loss = ((probs - F.one_hot(target, num_classes=probs.size(-1))) *
            logits).sum(-1)
    loss += alpha_entropy(probs, args.alpha)
    loss = loss.sum()

    true_token_logits = -F.nll_loss(logits[0], target[0], reduction='none')
    ntokens = inp.numel()

    arange = np.arange(probs.size(1))
    next_token_probs = probs[:, arange, target.squeeze().tolist()]
    voc_sizes = probs.size(-1)
    smoothed_nll = -torch.mean(
        torch.log((next_token_probs + args.laplas_eps) /
                  (1 + args.laplas_eps * voc_sizes)))

    logging_output = TrainingMetrics.ranking_metrics(logits[0].float(),
                                                     true_token_logits, None,
                                                     ntokens, target[0])
    logging_output['loss'] = loss.item()
    logging_output['smoothed_nll_loss'] = smoothed_nll.item()
    logging_output['normalizer'] = ntokens
    logging_output['sample_size'] = ntokens
    logging_output['ntokens'] = ntokens
    logging_output['js_div'] = jensen_shannon_divergence(probs,
                                                         target).mean().item()
    print(logging_output['js_div'])

    loss = loss / ntokens

    return loss, logging_output
Example 8
    def forward(self, hidden, orig_prob, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.
        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           attn (FloatTensor): attn for each ``(batch x tlen, slen)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab, of shape
               ``(src_len, batch, extra_words)``
        """

        # CHECKS
        # batch_by_tlen, _ = hidden.size()
        # batch_by_tlen_, slen = attn.size()
        onehot_src_map = \
            F.one_hot(src_map.long(), torch.max(src_map).long() + 1)
        batch, slen, cvocab = onehot_src_map.size()

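        # output distribution over the target vocabulary: entmax with alpha=1.2
        # yields a (mildly) sparse distribution, softmax a dense one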
        if self.use_entmax:
            prob = entmax_bisect(orig_prob, 1.2)
        else:
            prob = torch.softmax(orig_prob, 1)

        # Probability of copying p(z=1) batch.
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy)
        mul_attn = torch.mul(attn, p_copy)
        copy_prob = torch.bmm(
            mul_attn.view(batch, -1, slen),  # batch size x tgt len x src len
            onehot_src_map.float())  # batch size x src len x cvocab
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return out_prob, copy_prob
Example 9
 def forward(self):
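     # n_iter bounds the number of bisection steps used to find the entmax threshold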
     self.Y = entmax_bisect(self.X, self.alpha, dim=-1, n_iter=self.n_iter)
Example 10
def log_entmax(*args, **kwargs):
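    # entmax_bisect can return exact zeros, so the log may contain -inf entries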
    return torch.log(entmax_bisect(*args, **kwargs))
Example 11
def eval_singletoken(model,
                     args,
                     dataset_paths,
                     config,
                     top_k=1,
                     top_p=0.0,
                     t=1.0,
                     train_iter=None,
                     batch_size=None):
    alpha_entmax = args.alpha_entmax

    batch_size = batch_size if batch_size is not None else args.batch_size_singletoken
    datasets = get_datasets(dataset_paths, max_len=batch_size)
    eval_sampler = SequentialSampler(datasets[args.eval_split])
    eval_dataloader = DataLoader(
        datasets[args.eval_split], sampler=eval_sampler, batch_size=1)

    model.eval()

    logging_outputs = []
    predicted_tokens = []
    target_tokens = []
    with torch.no_grad():
        for i, batch in tqdm(enumerate(eval_dataloader),
                             desc="Evaluating", total=len(eval_dataloader)):
            longer_sample = batch[0].to(args.gpu)
            inp = longer_sample[:, :args.batch_size_singletoken]
            model_output = model(input_ids=inp)
            target = longer_sample[:, 1:]
            logits = model_output[0]
            log_softmax_probs = F.log_softmax(logits, dim=-1)
            nll = F.nll_loss(log_softmax_probs[0], target[0], reduction='sum')
            true_token_logits = -F.nll_loss(logits[0], target[0],
                                            reduction='none')

            if alpha_entmax is False:
                filtered_logits = top_k_top_p_filtering(
                    logits.squeeze(0), top_k=args.top_k, top_p=args.top_p).unsqueeze(0)
                prev = F.softmax(
                    filtered_logits.view(filtered_logits.shape[1:]),
                    dim=-1).multinomial(num_samples=1).unsqueeze(0).squeeze(-1)
                probs = F.softmax(filtered_logits, dim=-1)
            else:
                probs = entmax_bisect(logits, torch.tensor(
                    [args.alpha], requires_grad=True, device=torch.device(args.gpu)).float())
            arange = np.arange(logits.size(1))

            next_token_probs = probs[:, arange, target.squeeze().tolist()]
            voc_sizes = probs.size(-1)
            smoothed_nll = -torch.mean(torch.log(
                (next_token_probs + args.laplas_eps) / (1 + args.laplas_eps * voc_sizes)
            ))

            pred = probs.view(-1, probs.size(-1)
                              ).multinomial(num_samples=1).view(probs.shape[:-1])
            predicted_tokens.extend(pred.view(-1).tolist())
            ntokens = inp.numel()

            rep_logits = torch.zeros_like(logits)
            rep_logits[:, arange, pred.squeeze().tolist()] = 1
            logging_output = TrainingMetrics.ranking_metrics(
                rep_logits[0].float(), true_token_logits, None, ntokens, target[0])
            logging_output['loss'] = nll.item()
            logging_output['smoothed_nll_loss'] = smoothed_nll.item()
            logging_output['normalizer'] = ntokens
            logging_output['sample_size'] = ntokens
            logging_output['ntokens'] = ntokens
            logging_output['js_div'] = jensen_shannon_divergence(
                probs, target).mean().item()
            if args.token_loss == 'alpha_entmax':
                loss = ((probs - F.one_hot(target,
                                           num_classes=probs.size(-1))) * logits).sum(-1)
                loss += alpha_entropy(probs, args.alpha)
                logging_output['alpha_entmax_loss'] = loss.mean().item()
            logging_outputs.append(logging_output)

            # for human uniq
            target_tokens.extend(target.view(-1).tolist())

    logging_average = CrossEntropyCriterionWCustomMetrics.aggregate_logging_outputs(
        logging_outputs)
    logging_average['e_ppl'] = np.exp(
        np.mean([x['smoothed_nll_loss'] for x in logging_outputs]))
    # aggregate_logging_outputs does division by log(2) of loss
    logging_average['ppl'] = 2**logging_average['loss']
    logging_average['human_uniq'] = len(set(target_tokens))
    logging_average['uniq'] = len(set(predicted_tokens))
    logging_average['wrep'] = np.mean(
        [v for k, v in logging_average.items() if k.startswith('wrong_repeat')])
    logging_average['rep'] = np.mean(
        [v for k, v in logging_average.items() if k.startswith('repeat')])
    logging_average['js_div'] = np.mean([x['js_div'] for x in logging_outputs])
    if args.token_loss == 'alpha_entmax':
        logging_average['alpha_entmax_loss'] = np.mean(
            [x['alpha_entmax_loss'] for x in logging_outputs])

    save_singletoken_sampling_metrics(
        logging_average,
        config.to_dict(),
        args,
        top_k=top_k,
        top_p=top_p,
        train_iter=train_iter)

    return logging_average
Example 12
def sample_sequence(model,
                    prefix_batch,
                    prefix_length,
                    continuation_length,
                    num_samples=1,
                    top_k=0,
                    top_p=0.0,
                    temperature=1.0,
                    alpha_entmax=False,
                    output_prefix_hidden=False,
                    repetition_penalty=1.0, **kwargs):
    continuation_logits = []
    context = prefix_batch
    context = torch.cat([context] * num_samples, 0)
    assert context.size(1) == prefix_length

    prev = context
    output = context
    past = None

    log_probs = torch.zeros(
        (num_samples *
         prefix_batch.size(0),
         continuation_length))

    policy_pis = []

    for i in range(continuation_length):
        model_output = model(input_ids=prev, past=past)
        logits, past = model_output[:2]
        if i == 0 and output_prefix_hidden:
            prefix_hidden = model_output[2]

        logits = logits[:, -1, :]
        logits = logits / temperature

        if repetition_penalty != 1.0:
            for ex_id, pert_logits in enumerate(logits):
                for token_idx in set(output[ex_id].tolist()):
                    if pert_logits[token_idx] < 0:
                        pert_logits[token_idx] *= repetition_penalty
                    else:
                        pert_logits[token_idx] /= repetition_penalty
        if alpha_entmax is False:
            if top_k == 1 and top_p == 0:
                filtered_logits = logits
                prev = logits.float().argmax(dim=1, keepdim=True)
            else:
                filtered_logits = top_k_top_p_filtering(
                    logits, top_k=top_k, top_p=top_p)
                prev = F.softmax(filtered_logits,
                                 dim=-1).multinomial(num_samples=1)

            #log_prob = F.log_softmax(filtered_logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
        else:
            alpha = kwargs.get('alpha', 1.0)
            prob = entmax_bisect(
                logits,
                torch.tensor(
                    [alpha],
                    requires_grad=True,
                    device=logits.device).float())
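            # tokens pruned by entmax have probability exactly 0, so their
            # log-probability below is -inf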
            log_prob = torch.log(prob)
            prev = prob.multinomial(num_samples=1)
            filtered_logits = logits

        continuation_logits.append(logits)
        output = torch.cat((output, prev), dim=1)

        arange = np.arange(filtered_logits.size(0))
        next_token_logit = filtered_logits[arange,
                                           prev.squeeze().tolist()].squeeze()

        next_token_log_prob = log_prob[arange,
                                       prev.squeeze().tolist()].squeeze()
        log_probs[:, i] = next_token_log_prob
        policy_pis.append(log_prob.squeeze())

    policy_pis = torch.stack(policy_pis, 1)

    continuation_logits = torch.stack(continuation_logits, 1)
    if output_prefix_hidden:
        result = (
            output,
            log_probs,
            continuation_logits,
            policy_pis,
            prefix_hidden)
    else:
        result = (output, log_probs, continuation_logits, policy_pis)
    return result
Example 13
    def _generate_beam(self,
                       src_enc,
                       src_mask,
                       beam_size,
                       length_penalty=0.0,
                       early_stopping=False,
                       min_len=0,
                       max_len=200,
                       trigram_blocking=False,
                       return_all=False,
                       src_map=None,
                       src_tgt_vocab_map=None):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        `langs`:
            - must be None if the model only supports one language
            - lang_id if only one language is involved (LM)
            - (lang_id1, lang_id2) if two languages are involved (MT)
        """

        # check inputs
        assert src_enc.size(0) == src_mask.size(0)
        assert beam_size >= 1

        # batch size / number of words
        bs = len(src_mask)
        n_words = self.n_words if not self.use_copy else self.n_words + src_tgt_vocab_map.shape[
            1]

        # expand to beam size the source latent representations / source lengths
        src_enc = src_enc.unsqueeze(
            1).expand((bs, beam_size) +
                      src_enc.shape[1:]).contiguous().view((bs * beam_size, ) +
                                                           src_enc.shape[1:])
        src_mask = src_mask.unsqueeze(1).expand(
            (bs, beam_size) +
            src_mask.shape[1:]).contiguous().view((bs * beam_size, ) +
                                                  src_mask.shape[1:])
        if src_tgt_vocab_map is not None:
            src_tgt_vocab_map = src_tgt_vocab_map.unsqueeze(
                1).expand((bs, beam_size) +
                          src_tgt_vocab_map.shape[1:]).contiguous().view(
                              (bs * beam_size, ) + src_tgt_vocab_map.shape[1:])
        if src_map is not None:
            src_map = src_map.unsqueeze(1).expand(
                (bs, beam_size) +
                src_map.shape[1:]).contiguous().view((bs * beam_size, ) +
                                                     src_map.shape[1:])
        # src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)

        # generated sentences (batch with beam current hypotheses)
        generated = src_enc.new(bs * beam_size, max_len)  # upcoming output
        generated.fill_(self.pad_index)  # fill upcoming output with <PAD>
        generated[:,
                  0].fill_(self.bos_index)  # we use <EOS> for <BOS> everywhere

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(beam_size, max_len, length_penalty, early_stopping)
            for _ in range(bs)
        ]
        trigram_set = [set() for _ in range(bs * beam_size)]

        # scores for each sentence in the beam
        beam_scores = src_enc.new(bs, beam_size).fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        # current position
        cur_len = 1

        # cache compute states
        cache = {'slen': 0}

        # done sentences
        done = [False for _ in range(bs)]

        while cur_len < max_len:

            # compute word scores
            tensor, _ = self.fwd(
                x=generated[:, :cur_len]
                if not self.use_copy else generated[:, :cur_len].masked_fill(
                    generated[:, :cur_len].gt(self.n_words - 1), 0),
                src_enc=src_enc,
                src_mask=src_mask,
                cache=cache,
                src_map=src_map)
            if self.use_copy:
                tensor = torch.cat(tensor, 1)
                scores, _ = model_utils.collapse_copy_scores(
                    scores=tensor,
                    src_tgt_vocab_map=src_tgt_vocab_map,
                    vocab_size=self.n_words)
                scores[:, self.n_words] = 0
            else:
                assert tensor.size() == (bs * beam_size, 1, self.n_words)
                scores = tensor[:, -1, :]  # (bs * beam_size, dim)

            scores[:, 0] = -float('Inf') if not self.use_copy else 0
            scores[:,
                   self.pad_index] = -float('Inf') if not self.use_copy else 0
            scores[:,
                   self.bos_index] = -float('Inf') if not self.use_copy else 0

            if cur_len < min_len:
                scores[:, self.
                       eos_index] = -float('Inf') if not self.use_copy else 0

            if self.use_copy:
                scores = (scores + 1e-10).log()
            elif self.use_entmax:
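                # the small epsilon keeps the log finite for tokens that
                # entmax assigns probability zero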
                scores = torch.log(entmax_bisect(scores, 1.2) + 1e-10)
            else:
                scores = F.log_softmax(scores,
                                       dim=-1)  # (bs * beam_size, n_words)

            assert scores.size() == (bs * beam_size,
                                     n_words), (scores.shape, (bs * beam_size,
                                                               n_words))

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(
                scores)  # (bs * beam_size, n_words)
            _scores = _scores.view(bs, beam_size *
                                   n_words)  # (bs, beam_size * n_words)

            next_scores, next_words = torch.sort(_scores,
                                                 dim=1,
                                                 descending=True)
            assert next_scores.size() == next_words.size() == (bs, n_words *
                                                               beam_size)

            # next batch beam content
            # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for sent_id in range(bs):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[
                    sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend([(0, self.pad_index, 0)] *
                                           beam_size)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []
                n_add = 0

                # next words for this sentence
                for idx, value in zip(next_words[sent_id],
                                      next_scores[sent_id]):

                    # get beam and word IDs
                    beam_id = idx // n_words
                    word_id = idx % n_words

                    if trigram_blocking and cur_len > 2:
                        trigram = tuple(generated[sent_id * beam_size +
                                                  beam_id, cur_len -
                                                  2:cur_len].tolist() +
                                        [word_id.item()])
                        if trigram in trigram_set[sent_id * beam_size +
                                                  beam_id]:
                            continue
                    # end of sentence, or next word
                    if word_id == self.eos_index or cur_len + 1 == max_len:
                        n_add += 1
                        generated_hyps[sent_id].add(
                            generated[sent_id * beam_size +
                                      beam_id, :cur_len].clone(), value.item())
                    else:
                        next_sent_beam.append(
                            (value, word_id, sent_id * beam_size + beam_id))
                        if trigram_blocking and cur_len > 2:
                            trigram_set[sent_id * beam_size +
                                        beam_id].add(trigram)

                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size or (
                            cur_len + 1 == max_len and n_add == beam_size):
                        break

                # update next beam content
                assert len(next_sent_beam
                           ) == 0 if cur_len + 1 == max_len else beam_size
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, self.pad_index, 0)
                                      ] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == beam_size * (sent_id + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == bs * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = generated.new([x[2] for x in next_batch_beam]).long()

            # re-order batch and internal states
            trigram_set = [
                deepcopy(trigram_set[x[2]]) for x in next_batch_beam
            ]
            generated = generated[beam_idx, :]
            generated[:, cur_len] = beam_words
            for k in cache.keys():
                if k != 'slen':
                    cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        if return_all:
            return generated_hyps

        # select the best hypotheses
        tgt_len = src_enc.new(bs).long()
        best = []
        best_scores = []

        for i, hypotheses in enumerate(generated_hyps):
            best_score, best_hyp = max(hypotheses.hyp, key=lambda x: x[0])
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)
            best_scores.append(best_score)

        # generate target batch
        decoded = src_enc.new(tgt_len.max().item(), bs).fill_(self.pad_index)
        for i, hypo in enumerate(best):
            decoded[:tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = self.eos_index

        # sanity check
        assert (decoded == self.eos_index).sum() == bs

        return decoded.transpose(
            0, 1).cpu().numpy(), best_scores, tgt_len.cpu().numpy()
Example 14
    def _generate(self,
                  src_enc,
                  src_mask,
                  max_len=200,
                  min_len=0,
                  top_p=None,
                  src_map=None,
                  src_tgt_vocab_map=None):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        `langs`:
            - must be None if the model only supports one language
            - lang_id if only one language is involved (LM)
            - (lang_id1, lang_id2) if two languages are involved (MT)
        """

        # input batch
        bs = len(src_mask)
        assert src_enc.size(0) == bs

        # generated sentences
        generated = src_mask.new(bs, max_len)  # upcoming output
        generated.fill_(self.pad_index)  # fill upcoming output with <PAD>
        generated[:,
                  0].fill_(self.bos_index)  # we use <EOS> for <BOS> everywhere

        # current position / max lengths / length of generated sentences / unfinished sentences
        cur_len = 1
        # gen_len = torch.ones(bs).to(src_mask.device).long()
        unfinished_sents = torch.ones(bs).to(
            src_mask.device).long()  #src_len.clone().fill_(1)
        all_scores = torch.zeros(bs).to(src_mask.device)
        # cache compute states
        cache = {'slen': 0}

        while cur_len < max_len:

            # compute word scores
            tensor, _ = self.fwd(
                x=generated[:, :cur_len]
                if not self.use_copy else generated[:, :cur_len].masked_fill(
                    generated[:, :cur_len].gt(self.n_words - 1), 0),
                src_enc=src_enc,
                src_mask=src_mask,
                cache=cache,
                src_map=src_map)
            if self.use_copy:
                tensor = torch.cat(tensor, 1)
                scores, _ = model_utils.collapse_copy_scores(
                    scores=tensor,
                    src_tgt_vocab_map=src_tgt_vocab_map,
                    vocab_size=self.n_words)
                scores[:, self.n_words] = 0.0
            else:
                assert tensor.size() == (bs, 1,
                                         self.n_words), (cur_len, max_len,
                                                         src_enc.size(),
                                                         tensor.size(),
                                                         (1, bs, self.n_words))
                scores = tensor[:, -1, :]  # (bs, dim)
            # scores = self.pred_layer.get_scores(tensor)      # (bs, n_words)

            scores[:, 0] = -float('Inf') if not self.use_copy else 0
            scores[:,
                   self.pad_index] = -float('Inf') if not self.use_copy else 0
            scores[:,
                   self.bos_index] = -float('Inf') if not self.use_copy else 0

            if cur_len < min_len:
                scores[:, self.
                       eos_index] = -float('Inf') if not self.use_copy else 0

            # select next words: sample or greedy
            if top_p:
                if self.use_copy:
                    next_words = torch.multinomial(
                        model_utils.top_k_top_p_filtering(
                            scores,
                            top_k=0.0,
                            top_p=top_p if top_p else 0.0,
                            filter_value=0.0,
                            need_softmax=False), 1).squeeze(1)
                    next_scores = (scores + 1e-10).log().gather(
                        1, next_words.unsqueeze(1)).squeeze(1)
                else:
                    next_words = torch.multinomial(
                        F.softmax(model_utils.top_k_top_p_filtering(
                            scores, top_k=0.0, top_p=top_p if top_p else 0.0),
                                  dim=1), 1).squeeze(1)
                    next_scores = scores.log_softmax(1).gather(
                        1, next_words.unsqueeze(1)).squeeze(1)
            else:
                if self.use_copy:
                    next_scores, next_words = (scores + 1e-10).log().max(1)
                elif self.use_entmax:
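                    # as in beam search above, the epsilon avoids log(0) for
                    # tokens zeroed out by entmax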
                    next_scores, next_words = (entmax_bisect(scores, 1.2) +
                                               1e-10).log().max(1)
                else:
                    next_scores, next_words = scores.log_softmax(1).max(1)
            assert next_words.size() == (bs, )

            # update generations / lengths / finished sentences / current length
            generated[:,
                      cur_len] = next_words * unfinished_sents + self.pad_index * (
                          1 - unfinished_sents)
            all_scores = all_scores + next_scores * unfinished_sents.float()
            # gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(next_words.ne(self.eos_index).long())
            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

        # add <EOS> to unfinished sentences
        if cur_len == max_len:
            generated[:, -1].masked_fill_(unfinished_sents.bool(),
                                          self.eos_index)

        # sanity check
        assert (generated == self.eos_index).sum() == bs

        return generated[:, 1:cur_len].cpu().numpy(), all_scores.cpu().numpy()  # , gen_len
Example 15
    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        incremental_state=None,
        need_weights=True,
        static_kv=False,
        attn_mask=None,
        before_softmax=False,
        need_head_weights=False,
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
                ],
                                             dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            key_padding_mask = self._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=saved_state.get('prev_key_padding_mask',
                                                      None),
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state['prev_key_padding_mask'] = key_padding_mask

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size(
            []):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask)
                ],
                                             dim=1)
        if not bmm_fp16_support:
            q = q.float()
            k = k.float()
            v = v.float()
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if not bmm_fp16_support:
            attn_weights = attn_weights.type_as(query)
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        if before_softmax:
            return attn_weights, v
        # 1
        if not self.cur_san_active:
            self.div = 0
        if self.div > 0:
            top_k = int(torch.ceil(torch.Tensor([src_len / self.div])))
            if top_k < self.lb:
                top_k = self.lb
                if top_k > src_len:
                    top_k = src_len
        else:
            top_k = -self.div
            if top_k > src_len:
                top_k = src_len
        # 2
        # print('attn_weights ', attn_weights.size())
        if self.entmax:
            from entmax import sparsemax, entmax15, entmax_bisect
            # assign attn_weights_float in every branch; the shared code below expects it
            if self.entmax == 1:
                attn_weights_float = sparsemax(attn_weights.float(),
                                               dim=-1).type_as(attn_weights)
            elif self.entmax == 2:
                attn_weights_float = entmax15(attn_weights.float(),
                                              dim=-1).type_as(attn_weights)
            elif self.entmax == 3:
                attn_weights_float = entmax_bisect(
                    attn_weights.float(), dim=-1).type_as(attn_weights)
        else:
            if self.div:
                vk, _ = torch.topk(attn_weights, top_k)
                # print(value)
                tk = vk[:, :, -1].unsqueeze(2).expand_as(attn_weights)
                mask_k = torch.lt(attn_weights, tk)
                attn_weights = attn_weights.masked_fill(
                    mask_k, float('-inf')).type_as(attn_weights)
            attn_weights_float = utils.softmax(attn_weights,
                                               dim=-1,
                                               onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                               p=self.dropout,
                               training=self.training)
        if not bmm_fp16_support:
            attn_probs = attn_probs.float(
            )  # bsz * self.num_heads, tgt_len, src_len
        attn = torch.bmm(attn_probs, v)
        if not bmm_fp16_support:
            attn = attn.type_as(query)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if (self.onnx_trace and attn.size(1) == 1):
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, attn_weights