Example #1
def attention(query, key, value, params, mask=None, dropout=None, alpha=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    if params.attn_type == 'softmax':
        p_attn = F.softmax(scores, dim=-1)
    elif params.attn_type == 'sparsemax':
        p_attn = sparsemax(scores, dim=-1)
    elif params.attn_type == 'entmax15':
        p_attn = entmax15(scores, dim=-1)
    elif params.attn_type == 'entmax':
        p_attn = entmax_bisect(scores, alpha=alpha, dim=-1, n_iter=25)
    else:
        raise ValueError("unknown attn_type: %s" % params.attn_type)
    if dropout is not None:
        p_attn = dropout(p_attn)
    p_attn = p_attn.to(torch.float32)
    return torch.matmul(p_attn, value), scores, p_attn
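A minimal usage sketch for the function above (not part of the original snippet). It spells out the imports the example relies on and uses SimpleNamespace as a hypothetical stand-in for the params object with an attn_type field.

# Usage sketch; the setup below is assumed, not taken from the source.
import math
import torch
import torch.nn.functional as F
from types import SimpleNamespace
from entmax import sparsemax, entmax15, entmax_bisect

query = torch.randn(2, 4, 8, 16)   # [batch, heads, seq_len, d_k]
key = torch.randn(2, 4, 8, 16)
value = torch.randn(2, 4, 8, 16)
params = SimpleNamespace(attn_type='entmax15')  # hypothetical config holder

context, scores, p_attn = attention(query, key, value, params)
print(context.shape)       # torch.Size([2, 4, 8, 16])
print(p_attn.sum(dim=-1))  # each attention row sums to 1; exact zeros are allowed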
Example #2
    def forward(self, query, key, mask=None):
        # query and key are two copies of the sentence representation H
        # query: [nbatches, seq_len, d_model]
        # key: [nbatches, seq_len, d_model]
        # mask: [nbatches, seq_len], key padding mask
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query = self.w_q(query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        key = self.w_k(key).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        # [nbatches, h, seq_len, d_k]

        # 2) Apply attention on all the projected vectors in batch.
        # Compute 'Scaled Dot Product Attention'
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(self.d_k)  # [nbatches, h, seq_len, seq_len]
        if mask is not None:
            key_padding_mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(key_padding_mask == 0, float("-inf"))
        p_attn = entmax15(scores, dim=-1)  # [nbatches, h, seq_len, seq_len]
        if self.dropout is not None:
            p_attn = self.dropout(p_attn)

        # 3) Average the attention maps over heads.
        p_attn = torch.sum(p_attn, dim=1) / self.h
        return p_attn  # [nbatches, seq_len, seq_len]
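A small shape sketch (illustrative, not from the original class) of the key-padding-mask broadcast used above: a (nbatches, seq_len) mask becomes (nbatches, 1, 1, seq_len) so it broadcasts over heads and query positions.

import torch

mask = torch.tensor([[1, 1, 1, 0]])                # [nbatches=1, seq_len=4]
key_padding_mask = mask.unsqueeze(1).unsqueeze(2)  # [1, 1, 1, 4]
scores = torch.randn(1, 2, 4, 4)                   # [nbatches, h, seq_len, seq_len]
scores = scores.masked_fill(key_padding_mask == 0, float("-inf"))
print(scores[0, 0])                                # last key column is -inf for every query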
Example #3
    def forward(self, scores: torch.Tensor,
                mask: torch.BoolTensor) -> torch.Tensor:
        """Map a score vector to a probability distribution halfway between softmax and sparsemax

        Args:
            scores (torch.Tensor): (Batch x Sequence Length)
                Attention scores (also referred to as weights)
            mask (torch.BoolTensor): (Batch x Sequence Length)
                Specifies which indices are just padding

        Returns:
            torch.Tensor: Distribution halfway between softmax and sparsemax
        """
        masked_scores = replace_masked_values(scores, mask, -float("inf"))
        return entmax15(masked_scores, dim=-1)
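A standalone sketch of the same masking pattern, under the assumption that replace_masked_values behaves like the AllenNLP helper of that name (fill positions where the mask is False with the given value).

import torch
from entmax import entmax15

scores = torch.randn(3, 5)
mask = torch.tensor([[True, True, True, False, False]] * 3)
# assumed equivalent of replace_masked_values(scores, mask, -float("inf"))
masked_scores = scores.masked_fill(~mask, -float("inf"))
probs = entmax15(masked_scores, dim=-1)
print(probs)  # padded positions receive exactly zero probability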
Example #4
def sparse_attention(query, key, value, alpha, mask=None, dropout=None):
    "Use sparse activation function"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    if alpha == 2:
        p_attn = entmax.sparsemax(scores, -1)
    elif alpha == 1.5:
        p_attn = entmax.entmax15(scores, -1)
    else:
        raise NotImplementedError
    if dropout is not None:
        p_attn = dropout(p_attn)
    # return torch.matmul(p_attn, value), scores.squeeze(1).squeeze(1)
    return torch.matmul(p_attn, value), p_attn
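For context, an illustrative comparison (not part of the example): alpha interpolates between softmax (alpha = 1) and sparsemax (alpha = 2), with entmax15 in between, and larger alpha gives sparser attention.

import torch
import torch.nn.functional as F
from entmax import entmax15, sparsemax

scores = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
for name, p in [("softmax", F.softmax(scores, dim=-1)),
                ("entmax15", entmax15(scores, dim=-1)),
                ("sparsemax", sparsemax(scores, dim=-1))]:
    print(name, p, "nonzeros:", int((p > 0).sum()))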
Example #5
def nce_loss_entmax(positive, negatives, temperature):
    """Contrastive (NCE-style) loss with entmax15 in place of softmax.

    :param positive: b * k
    :param negatives: b * k * num_negatives
    :param temperature: unused in this snippet
    :return: scalar loss (summed over k, averaged over the batch)
    """
    from entmax import entmax15
    negatives_and_positive = torch.cat([positive.unsqueeze(2), negatives],
                                       dim=2)
    probs = entmax15(negatives_and_positive, dim=2)

    loss_batch = torch.log(probs[:, :, 0] + 1e-8)

    # sum over k, mean over batches
    loss = -loss_batch.sum(1).mean(0)
    # loss = -torch.mean(loss_batch)
    return loss
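A hypothetical usage sketch with the shapes from the docstring; note that temperature is accepted but not used in the snippet above.

import torch

b, k, num_negatives = 4, 3, 7
positive = torch.randn(b, k)
negatives = torch.randn(b, k, num_negatives)
loss = nce_loss_entmax(positive, negatives, temperature=1.0)
print(loss)  # scalar: summed over k, averaged over the batch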
Example #6
    def forward(self, char_encoder_result, tag_encoder_result, true_output_seq=None):
        char_encoding, (char_hn, char_cn) = char_encoder_result
        batch_size = len(char_encoding)
        tag_encoding, (tag_hn, tag_cn) = tag_encoder_result
        char_encoding = torch.transpose(char_encoding, 0, 1)  # move seq_len dimension to the front
        tag_encoding = torch.transpose(tag_encoding, 0, 1)
        current_input = torch.zeros((batch_size, self.n_chars), device=char_encoding.device)
        last_cell_state = (
            #torch.cat((char_hn[0], tag_hn[0]), dim=-1),
            #torch.cat((char_cn[0], tag_cn[0]), dim=-1)
            torch.cat((char_hn[0], char_hn[1]), dim=-1),
            torch.cat((char_cn[0], char_cn[1]), dim=-1)
        )

        def time_step_fn(input_1, state_0):
            h1, c1 = self.lstm_cell(input_1, state_0)
            query = torch.unsqueeze(h1, dim=0)  # use cell output as query
            char_attention, _ = self.char_attention(query=query, key=char_encoding, value=char_encoding)
            tag_attention, _ = self.tag_attention(query=query, key=tag_encoding, value=tag_encoding)
            aggregated_attention = torch.cat([char_attention, tag_attention], dim=-1).squeeze(0)
            aggregated_attention = torch.relu(aggregated_attention)
            output = self.output_layer(aggregated_attention)  # relu instead?
            return output, (h1, c1)

        top = [[current_input, last_cell_state, [], 0]]  # beam search candidates; last entry is log probability
        teacher_forcing = true_output_seq is not None
        for time_step in range(len(char_encoding)):
            time_step_leaders = []
            for candidate in top:
                next_input, current_cell_state, current_output_seq, sequence_probability = candidate
                candidate_output, candidate_next_state = time_step_fn(next_input, current_cell_state)
                if teacher_forcing:  # teacher forcing; in this scenario, top only has 1 item
                    top = [[None, candidate_next_state, current_output_seq + [candidate_output], 1]]
                    top[0][0] = true_output_seq[:, time_step, :]
                    continue
                else:
                    probabilities = entmax15(candidate_output, dim=-1)
                    tk = torch.topk(probabilities, self.beam_size, dim=-1)
                    top_indices = tk.indices[0]
                    top_probs = tk.values[0]
                    for i in range(self.beam_size):
                        time_step_leaders.append(
                            [top_indices[i], top_probs[i], candidate_next_state, current_output_seq,
                             sequence_probability + torch.log(top_probs[i])]
                        )
            if not teacher_forcing:
                new_top = []
                time_step_leaders.sort(key=lambda x: x[4])
                beam_size = self.beam_size
                if time_step == self.beam_size - 1:
                    beam_size = 1
                for leader in time_step_leaders[-beam_size:]:
                    leader_index, leader_prob, leader_next_state, leader_current_output_seq, probability = leader
                    one_hot = torch.nn.functional.one_hot(leader_index, num_classes=self.n_chars)
                    one_hot = torch.unsqueeze(one_hot, dim=0).float()  # add batch dimension
                    new_top.append([one_hot, leader_next_state, leader_current_output_seq + [one_hot], probability])
                top = new_top

        return_sequence = top[0][2]
        return_sequence = torch.stack(return_sequence)
        return torch.transpose(return_sequence, 0, 1)
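A tiny sketch (hypothetical sizes, not the module itself) of the per-step candidate expansion above: entmax15 maps one candidate's decoder output to a possibly sparse distribution, and topk keeps the beam_size most probable next characters.

import torch
from entmax import entmax15

beam_size, n_chars = 3, 10
candidate_output = torch.randn(1, n_chars)        # one beam candidate, batch of 1
probabilities = entmax15(candidate_output, dim=-1)
tk = torch.topk(probabilities, beam_size, dim=-1)
print(tk.indices[0], tk.values[0])                # next-character candidates and their probabilities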
Example #7
def eval_semisuper_vae(vae,
                       classifier,
                       loader_unlabeled,
                       super_loss,
                       loader_labeled=[None],
                       train=False,
                       optimizer=None,
                       topk=0,
                       grad_estimator=bs_lib.reinforce,
                       grad_estimator_kwargs={'grad_estimator_kwargs': None},
                       n_samples=1,
                       train_labeled_only=False,
                       epoch=0,
                       baseline_optimizer=None,
                       normalizer='softmax'):

    if train:
        assert optimizer is not None
        vae.train()
        classifier.train()

    else:
        vae.eval()
        classifier.eval()

    sum_loss = 0.0
    num_images = 0.0
    total_nz = 0.0

    for labeled_data, unlabeled_data in zip(cycle(loader_labeled), \
                                                loader_unlabeled):

        unlabeled_image = unlabeled_data['image'].to(device)

        if labeled_data is not None:
            labeled_image = labeled_data['image'].to(device)
            true_labels = labeled_data['label'].to(device)

            # get loss on labeled images
            supervised_loss = \
                get_supervised_loss(vae, classifier, labeled_image, true_labels, super_loss).sum()

            num_labeled = len(loader_labeled.sampler)
            num_labeled_batch = labeled_image.shape[0]

        else:
            supervised_loss = 0.0
            num_labeled = 0.0
            num_labeled_batch = 1.0

        # run through classifier
        scores = classifier.forward(unlabeled_image)

        if normalizer == 'softmax':
            class_weights = torch.softmax(scores, dim=-1)
        elif normalizer == 'entmax15':
            class_weights = entmax15(scores, dim=-1)
        elif normalizer == 'sparsemax':
            class_weights = sparsemax(scores, dim=-1)
        else:
            raise NameError("%s is not a valid normalizer!" % (normalizer, ))

        # get a mask of nonzeros
        nz = (class_weights > 0).to(class_weights.device)

        if train:

            train_labeled_only_bool = 1.
            if train_labeled_only:
                n_samples = 0
                train_labeled_only_bool = 0.

            # flush gradients
            optimizer.zero_grad()

            # get unlabeled pseudoloss: here we use our
            # Rao-Blackwellization or some other gradient estimator
            f_z = lambda z: vae_utils.get_loss_from_one_hot_label(
                vae, unlabeled_image, z)
            unlabeled_ps_loss = 0.0
            for i in range(n_samples):
                unlabeled_ps_loss_ = rb_lib.get_raoblackwell_ps_loss(
                    f_z,
                    class_weights,
                    topk=topk,
                    epoch=epoch,
                    data=unlabeled_image,
                    grad_estimator=grad_estimator,
                    grad_estimator_kwargs=grad_estimator_kwargs)

                unlabeled_ps_loss += unlabeled_ps_loss_

            unlabeled_ps_loss = unlabeled_ps_loss / max(n_samples, 1)

            kl_q = torch.sum(class_weights[nz] * torch.log(class_weights[nz]))

            total_ps_loss = \
                (unlabeled_ps_loss + kl_q) * train_labeled_only_bool * \
                len(loader_unlabeled.sampler) / unlabeled_image.shape[0] + \
                supervised_loss * num_labeled / labeled_image.shape[0]

            # backprop gradients from pseudo loss
            total_ps_loss.backward(retain_graph=True)
            optimizer.step()

            if baseline_optimizer is not None:
                # for RELAX: as it trains to minimize a control variate
                # flush gradients
                optimizer.zero_grad()
                # for params in classifier.parameters():
                baseline_optimizer.zero_grad()
                loss_grads = grad(total_ps_loss,
                                  classifier.parameters(),
                                  create_graph=True)
                gn2 = sum([grd.norm()**2 for grd in loss_grads])
                gn2.backward()
                baseline_optimizer.step()

        # loss at MAP value of z
        loss = \
            vae_utils.get_labeled_loss(vae, unlabeled_image,
                                torch.argmax(scores, dim = 1)).detach().sum()

        sum_loss += loss
        num_images += unlabeled_image.shape[0]

        total_nz += nz.sum().item()

    return sum_loss / num_images, total_nz / num_images
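An illustrative sketch of the kl_q term above (standalone setup assumed): with a sparse normalizer some class weights are exactly zero, so p * log(p) is summed only over the nonzero support to avoid log(0).

import torch
from entmax import sparsemax

scores = torch.randn(5, 10)
class_weights = sparsemax(scores, dim=-1)
nz = class_weights > 0
kl_q = torch.sum(class_weights[nz] * torch.log(class_weights[nz]))
print(kl_q)  # negative entropy, summed over the batch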
Example #8
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        prune_attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
            prune_attn_mask (Tensor, optional): boolean-castable mask of shape
                (1, num_heads, 1024, 1024); positions set to 1 are masked out
                (pruned) from the attention weights.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (
            self.enable_torch_version
            and not self.onnx_trace
            and incremental_state is None
            and not static_kv
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if prune_attn_mask is not None:
            if incremental_state is None: #train
                prune_attn_mask = prune_attn_mask.to(torch.bool)[:,:,0:tgt_len, 0:src_len]
            else: #generation
                prune_attn_mask = prune_attn_mask.to(torch.bool)[:,:,src_len+1:src_len+2, 0:src_len]
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(prune_attn_mask, -32768) #prune_mask is 1 where we want to mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        if self.USE_ENTMAX:
            attn_weights_float = entmax15(attn_weights.float(), dim=-1)
        else:
            attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)

        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )
        self.attn_probs = attn_probs.view(bsz, self.num_heads, tgt_len, src_len) #keep track of attention pattern for pruning experiments
        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
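A standalone shape sketch (illustrative, not the module's API) of the incremental-state cache handled above: keys and values saved from earlier decoding steps are concatenated with the current step's projections so they are not recomputed.

import torch

bsz_heads, head_dim = 8, 16
prev_key = torch.randn(bsz_heads, 5, head_dim)   # 5 cached positions
new_k = torch.randn(bsz_heads, 1, head_dim)      # projection for the current step
k = torch.cat([prev_key, new_k], dim=1)
print(k.shape)                                   # torch.Size([8, 6, 16])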
Example #9
def log_entmax15(*args, **kwargs):
    return torch.log(entmax15(*args, **kwargs))
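A small usage sketch (an assumption about intended use): because entmax15 can assign exact zeros, the log is -inf for classes outside the support, which matters if this replaces log_softmax in an NLL-style loss.

import torch
from entmax import entmax15

logits = torch.tensor([[4.0, 2.0, -3.0]])
log_p = log_entmax15(logits, dim=-1)
print(log_p)  # zero-probability classes map to -inf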
Example #10
    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        incremental_state=None,
        need_weights=True,
        static_kv=False,
        attn_mask=None,
        before_softmax=False,
        need_head_weights=False,
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
                ],
                                             dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            key_padding_mask = self._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=saved_state.get('prev_key_padding_mask',
                                                      None),
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state['prev_key_padding_mask'] = key_padding_mask

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask)
                ],
                                             dim=1)
        if not bmm_fp16_support:
            q = q.float()
            k = k.float()
            v = v.float()
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if not bmm_fp16_support:
            attn_weights = attn_weights.type_as(query)
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        if before_softmax:
            return attn_weights, v
        # 1
        if not self.cur_san_active:
            self.div = 0
        if self.div > 0:
            top_k = int(torch.ceil(torch.Tensor([src_len / self.div])))
            if top_k < self.lb:
                top_k = self.lb
                if top_k > src_len:
                    top_k = src_len
        else:
            top_k = -self.div
            if top_k > src_len:
                top_k = src_len
        # 2
        # print('attn_weights ', attn_weights.size())
        if self.entmax:
            from entmax import sparsemax, entmax15, entmax_bisect
            # every branch assigns attn_weights_float so the cast below is well-defined
            if self.entmax == 1:
                attn_weights_float = sparsemax(attn_weights.float(), dim=-1)
            elif self.entmax == 2:
                attn_weights_float = entmax15(attn_weights.float(), dim=-1)
            elif self.entmax == 3:
                attn_weights_float = entmax_bisect(attn_weights.float(), dim=-1)
        else:
            if self.div:
                vk, _ = torch.topk(attn_weights, top_k)
                tk = vk[:, :, -1].unsqueeze(2).expand_as(attn_weights)
                mask_k = torch.lt(attn_weights, tk)
                attn_weights = attn_weights.masked_fill(
                    mask_k, float('-inf')).type_as(attn_weights)
            attn_weights_float = utils.softmax(attn_weights,
                                               dim=-1,
                                               onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                               p=self.dropout,
                               training=self.training)
        if not bmm_fp16_support:
            attn_probs = attn_probs.float()  # bsz * self.num_heads, tgt_len, src_len
        attn = torch.bmm(attn_probs, v)
        if not bmm_fp16_support:
            attn = attn.type_as(query)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if (self.onnx_trace and attn.size(1) == 1):
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, attn_weights
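A standalone sketch (illustrative shapes) of the top-k branch above, used when entmax is disabled and div is set: only the top_k scores per query survive, and the rest are pushed to -inf before the softmax.

import torch
import torch.nn.functional as F

attn_weights = torch.randn(2, 4, 6)  # (bsz * num_heads, tgt_len, src_len)
top_k = 3
vk, _ = torch.topk(attn_weights, top_k)
threshold = vk[:, :, -1].unsqueeze(2).expand_as(attn_weights)
attn_weights = attn_weights.masked_fill(attn_weights < threshold, float("-inf"))
attn_probs = F.softmax(attn_weights, dim=-1)
print((attn_probs > 0).sum(-1))      # top_k nonzeros per row (ties aside)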
Example #11
    def forward(self, y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, step):

        if not self.training and step == 0:
            h_decoder, c_decoder = s_t_1
            s_t_hat = torch.cat((h_decoder.view(
                -1, config.hidden_dim), c_decoder.view(-1, config.hidden_dim)),
                                1)  # B x 2*hidden_dim
            c_t, _, coverage_next = self.attention_network(
                s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask,
                coverage)
            coverage = coverage_next

        y_t_1_embd = self.embedding(y_t_1)
        x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
        lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t_1)

        h_decoder, c_decoder = s_t
        s_t_hat = torch.cat((h_decoder.view(
            -1, config.hidden_dim), c_decoder.view(-1, config.hidden_dim)),
                            1)  # B x 2*hidden_dim
        c_t, attn_dist, coverage_next = self.attention_network(
            s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask,
            coverage)

        if self.training or step > 0:
            coverage = coverage_next

        p_gen = None
        if config.pointer_gen:
            p_gen_input = torch.cat((c_t, s_t_hat, x),
                                    1)  # B x (2*2*hidden_dim + emb_dim)
            p_gen = self.p_gen_linear(p_gen_input)
            p_gen = torch.sigmoid(p_gen)

        output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t),
                           1)  # B x hidden_dim * 3
        output = self.out1(output)  # B x hidden_dim

        output = self.out2(output)  # B x vocab_size

        T = config.temperature

        if config.tsallis_alpha == 1:
            vocab_dist = F.softmax(output / T, dim=1)
        elif config.tsallis_alpha == 1.5:
            vocab_dist = entmax15(output / T, dim=1)
        elif config.tsallis_alpha == 2:
            vocab_dist = sparsemax(output / T, dim=1)
        else:
            raise ValueError("unsupported tsallis_alpha: %s" % config.tsallis_alpha)

        # debug('vocab_dist', vocab_dist.size())
        if config.DEBUG and config.REC_ENTROPY:

            def get_entropy(t):
                return -torch.sum(torch.log(t) * t, dim=1)

            # vocab_dist_entropy = get_entropy(vocab_dist + config.eps)
            # debug('vocab_dist_entropy', vocab_dist_entropy)
            with open(
                    os.path.join(config.log_root,
                                 'vocab_dist_entropy/last_run.csv'), 'a') as f:

                f.write(','.join(
                    [str(i)
                     for i in vocab_dist.cpu().detach().numpy()]) + '\n')
                # if step == 0:
                #   f.write( str(self.batch_cnt) + '\n')
                #   self.batch_cnt += 1
                # f.write( ','.join([ str(i) for i in vocab_dist_entropy.cpu().detach().numpy().round(4)]  ) + '\n' )
                # if config.entmax_select:
                #   f.write( ','.join([ str(i) for i in p_soft.cpu().detach().numpy().round(4)]  ) + '\n')

        if config.pointer_gen:
            vocab_dist_ = p_gen * vocab_dist
            attn_dist_ = (1 - p_gen) * attn_dist

            if extra_zeros is not None:
                vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)

            final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab,
                                                 attn_dist_)

            # debug("extra_zeros", extra_zeros)
            # debug("vocab_dist_", vocab_dist_)
            # debug("attn_dist", attn_dist)
            # debug("enc_batch_extend_vocab", enc_batch_extend_vocab)
            # debug("final_dist", final_dist.size())
        else:
            final_dist = vocab_dist
        tau = None

        if config.adaptive_sparsemax:
            eps = torch.DoubleTensor([config.eps]).cuda(0)
            activation = torch.sigmoid
            tau = (1 - eps) * activation(
                self.p_sparse_linear(torch.cat((c_t, s_t_hat, x), 1)))
            debug('tau + eps', tau + eps)
            final_dist = sparsemax(final_dist / (tau + eps), dim=-1)
        elif config.use_top_p:
            final_dist = top_p(final_dist, config.top_p)
        # with open('tau.txt','a') as f:
        #   f.write(','.join([ str(i) for i in tau.cpu().detach().numpy().round(4)]) + '\n')
        # debug("top", final_dist.topk(10, -1))
        # debug("entropy", vocab_dist_entropy)

        return final_dist, s_t, c_t, attn_dist, p_gen, coverage, tau
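A minimal sketch (hypothetical sizes) of the pointer-generator mixing step above: the copy distribution over source positions is scattered into the extended vocabulary and blended with the generation distribution using p_gen.

import torch

vocab_size, extra, src_len = 6, 2, 4
p_gen = torch.tensor([[0.7]])
vocab_dist = torch.softmax(torch.randn(1, vocab_size), dim=1)
attn_dist = torch.softmax(torch.randn(1, src_len), dim=1)
enc_batch_extend_vocab = torch.tensor([[1, 3, 6, 7]])  # source token ids, including extended-vocab OOVs

vocab_dist_ = p_gen * vocab_dist
attn_dist_ = (1 - p_gen) * attn_dist
vocab_dist_ = torch.cat([vocab_dist_, torch.zeros(1, extra)], dim=1)
final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
print(final_dist.sum())  # still sums to 1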