Example #1
    def _test_jacobian(self, input_dim, hidden_dim):
        jacobian = torch.zeros(input_dim, input_dim)
        iaf = InverseAutoregressiveFlow(input_dim, hidden_dim, sigmoid_bias=0.5)

        def nonzero(x):
            return torch.sign(torch.abs(x))

        x = torch.randn(1, input_dim)
        iaf_x = iaf(x)
        analytic_ldt = iaf.log_abs_det_jacobian(x, iaf_x).data.sum()

        for j in range(input_dim):
            for k in range(input_dim):
                epsilon_vector = torch.zeros(1, input_dim)
                epsilon_vector[0, j] = self.epsilon
                iaf_x_eps = iaf(x + epsilon_vector)
                delta = (iaf_x_eps - iaf_x) / self.epsilon
                jacobian[j, k] = float(delta[0, k].data.sum())

        permutation = iaf.arn.get_permutation()
        permuted_jacobian = jacobian.clone()
        for j in range(input_dim):
            for k in range(input_dim):
                permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]]
        numeric_ldt = torch.sum(torch.log(torch.diag(permuted_jacobian)))
        ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt)

        diag_sum = torch.sum(torch.diag(nonzero(permuted_jacobian)))
        lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=-1))

        assert ldt_discrepancy < self.epsilon
        assert diag_sum == float(input_dim)
        assert lower_sum == float(0.0)
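The triangularity check above in miniature (a standalone sketch, not part of the original test): for an autoregressive transform, output k depends only on inputs up to k, so the permuted Jacobian must carry no mass above the diagonal, which torch.tril/torch.triu can assert directly.
import torch

# Hand-built lower triangular Jacobian standing in for the numerical one above.
J = torch.tensor([[1.0, 0.0, 0.0],
                  [0.5, 2.0, 0.0],
                  [0.1, 0.3, 1.5]])
assert torch.equal(torch.tril(J), J)                # nothing above the diagonal
assert torch.triu(J, diagonal=1).abs().sum() == 0   # equivalent statement via triu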
Example #2
    def __init__(self, hidden_size, num_inputs, action_space):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[0]

        self.bn0 = nn.BatchNorm1d(num_inputs)
        self.bn0.weight.data.fill_(1)
        self.bn0.bias.data.fill_(0)

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.bn1.weight.data.fill_(1)
        self.bn1.bias.data.fill_(0)

        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)
        self.bn2.weight.data.fill_(1)
        self.bn2.bias.data.fill_(0)

        self.V = nn.Linear(hidden_size, 1)
        self.V.weight.data.mul_(0.1)
        self.V.bias.data.mul_(0.1)

        self.mu = nn.Linear(hidden_size, num_outputs)
        self.mu.weight.data.mul_(0.1)
        self.mu.bias.data.mul_(0.1)

        self.L = nn.Linear(hidden_size, num_outputs ** 2)
        self.L.weight.data.mul_(0.1)
        self.L.bias.data.mul_(0.1)

        self.tril_mask = Variable(torch.tril(torch.ones(
            num_outputs, num_outputs), diagonal=-1).unsqueeze(0))
        self.diag_mask = Variable(torch.diag(torch.diag(
            torch.ones(num_outputs, num_outputs))).unsqueeze(0))
Example #3
def batch_tril(bmat, diagonal=0):
    """
    Given a batch of matrices, returns the lower triangular part of each matrix, with
    the other entries set to 0. The argument `diagonal` has the same meaning as in
    `torch.tril`.
    """
    if bmat.dim() == 2:
        return bmat.tril(diagonal=diagonal)
    else:
        return bmat * torch.tril(bmat.new(*bmat.shape[-2:]).fill_(1.0), diagonal=diagonal)
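A quick sanity check for batch_tril (a sketch, not from the original source): recent PyTorch releases already apply tril per matrix of a batched input, which is exactly what the helper emulates for older versions.
import torch

bmat = torch.randn(5, 4, 4)        # a batch of five 4x4 matrices
lower = bmat.tril(diagonal=0)      # per-matrix lower triangle on recent PyTorch
assert torch.equal(lower[0], bmat[0].tril())
assert (lower.triu(diagonal=1) == 0).all()   # everything above the diagonal is zero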
Example #4
def addOrthov2Regularizer(loss, model, regParam, targetLayers):
    for i in range(len(targetLayers)):
        layerParams = model[targetLayers[i]].named_parameters()
        for param in layerParams:  # don't regularize bias params
            if 'bias' not in param[0]:
                W = param[1].t()
                dotproducts = torch.mm(W.t(), W)  # the lower triangle (excluding diagonal) holds the dot products between all pairs of neurons
                norms = torch.norm(W, dim=0, keepdim=True)
                cosinesimilarities = dotproducts / norms / norms.t()
                C = (regParam * 0.5) * torch.sum(torch.tril(cosinesimilarities, diagonal=-1) ** 2)
                loss += C
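The same penalty as a standalone sketch (names and shapes are illustrative, not from the original repository): cosine similarities between neuron pairs are collected once each via the strict lower triangle and pushed toward zero.
import torch

W = torch.randn(64, 128, requires_grad=True)   # illustrative weight: neurons x inputs
dots = W @ W.t()                               # neuron-to-neuron dot products
norms = W.norm(dim=1, keepdim=True)            # per-neuron norms
cos = dots / (norms * norms.t())
# diagonal=-1 drops the self-similarity diagonal and counts each pair once
penalty = 0.5 * torch.sum(torch.tril(cos, diagonal=-1) ** 2)
penalty.backward()                             # gradients flow back into W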
Example #5
def decide2(prob_scores, volumes, C, n_samples):
    bs, N = prob_scores.size()
    sample = torch.zeros(bs, N).type(dtype)
    nn.init.uniform_(sample)
    sample = sample * F.softmax(prob_scores, dim=1).data
    _, inds = sample.sort(1, descending=True)
    volumes = volumes.gather(1, inds)
    M = torch.tril(torch.ones(N,N).type(dtype)).unsqueeze(0).expand(bs,N,N)
    sums = torch.bmm(M,volumes.unsqueeze(2)).squeeze(2)
    mask_chosen = sums <= C.unsqueeze(1).expand_as(sums)
    return mask_chosen, inds
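The tril-of-ones matrix in decide2 is a batched prefix sum; a small sketch (shapes are illustrative) confirms it matches torch.cumsum.
import torch

bs, N = 2, 5
volumes = torch.rand(bs, N)
M = torch.tril(torch.ones(N, N)).unsqueeze(0).expand(bs, N, N)
sums = torch.bmm(M, volumes.unsqueeze(2)).squeeze(2)   # running total of volumes
assert torch.allclose(sums, torch.cumsum(volumes, dim=1))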
Example #6
 def __call__(self, y_pred, y_true=None):
     """
     y_pred should be two projections
     """
     covar_mat = th.abs(th_matrixcorr(y_pred[0].data, y_pred[1].data))
     upper_sum = th.sum(th.triu(covar_mat,1))
     lower_sum = th.sum(th.tril(covar_mat,-1))
     self.anticorr_sum += upper_sum
     self.anticorr_sum += lower_sum
     self.total_count += covar_mat.size(0)*(covar_mat.size(1) - 1)
     return self.anticorr_sum / self.total_count
Example #7
 def __init__(self, nx, n_ctx, cfg, scale=False):
     super(Attention, self).__init__()
     n_state = nx  # in Attention: n_state=768 (nx=n_embd)
     # [switch nx => n_state from Block to Attention to keep identical to TF implem]
     assert n_state % cfg.n_head == 0
     self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
     self.n_head = cfg.n_head
     self.split_size = n_state
     self.scale = scale
     self.c_attn = Conv1D(n_state * 3, 1, nx)
     self.c_proj = Conv1D(n_state, 1, nx)
     self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
     self.resid_dropout = nn.Dropout(cfg.resid_pdrop)
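The registered buffer `b` above is the usual GPT-style causal mask. The snippet does not include the attention step itself, so the following is only an assumed sketch of how such a buffer is typically applied to the score matrix.
import torch
import torch.nn.functional as F

n_ctx, d = 8, 16
b = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)

q = torch.randn(1, 1, n_ctx, d)            # (batch, head, time, dim), illustrative shapes
k = torch.randn(1, 1, n_ctx, d)
w = q @ k.transpose(-2, -1) / d ** 0.5     # raw attention scores
w = w * b + -1e9 * (1 - b)                 # block future positions before softmax
w = F.softmax(w, dim=-1)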
Example #8
 def __init__(self,
              nx: int,
              n_ctx: int,
              config: TransformerConfig,
              scale: bool = False) -> None:
     super().__init__()
     n_state = nx  # in Attention: n_state=768 (nx=n_embd)
     # [switch nx => n_state from Block to Attention to keep identical to TF implem]
     assert n_state % config.num_heads == 0
     self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
     self.n_head = config.num_heads
     self.split_size = n_state
     self.scale = scale
     self.c_attn = Conv1D(n_state * 3, 1, nx)
     self.c_proj = Conv1D(n_state, 1, nx)
     self.attn_dropout = torch.nn.Dropout(config.attention_dropout_probability)
     self.resid_dropout = torch.nn.Dropout(config.residual_dropout_probability)
Example #9
    def cumulative_average_mask(self, batch_size, inputs_len):
        """
        Builds the mask to compute the cumulative average as described in
        :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3

        Args:
            batch_size (int): batch size
            inputs_len (int): length of the inputs

        Returns:
            (FloatTensor):

            * A Tensor of shape ``(batch_size, input_len, input_len)``
        """

        triangle = torch.tril(torch.ones(inputs_len, inputs_len))
        weights = torch.ones(1, inputs_len) / torch.arange(
            1, inputs_len + 1, dtype=torch.float)
        mask = triangle * weights.transpose(0, 1)

        return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len)
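A small check of the mask above (a sketch, not part of the original module): row i of the mask averages positions 0..i of the input sequence.
import torch

L = 4
triangle = torch.tril(torch.ones(L, L))
weights = torch.ones(1, L) / torch.arange(1, L + 1, dtype=torch.float)
mask = triangle * weights.transpose(0, 1)

x = torch.randn(L, 3)              # toy sequence of 3-dimensional inputs
cum_avg = mask @ x                 # (L, 3): running mean over each prefix
assert torch.allclose(cum_avg[2], x[:3].mean(dim=0))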
Example #10
    def _test_jacobian(self, input_dim, hidden_dim, multiplier):
        jacobian = torch.zeros(input_dim, input_dim)
        arn = AutoRegressiveNN(input_dim, hidden_dim, multiplier)

        def nonzero(x):
            return torch.sign(torch.abs(x))

        for output_index in range(multiplier):
            for j in range(input_dim):
                for k in range(input_dim):
                    x = torch.randn(1, input_dim)
                    epsilon_vector = torch.zeros(1, input_dim)
                    epsilon_vector[0, j] = self.epsilon
                    delta = (arn(x + epsilon_vector) - arn(x)) / self.epsilon
                    jacobian[j, k] = float(delta[0, k + output_index * input_dim])

            permutation = arn.get_permutation()
            permuted_jacobian = jacobian.clone()
            for j in range(input_dim):
                for k in range(input_dim):
                    permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]]

            lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=0))
            assert lower_sum == float(0.0)
Example #11
 def other_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     c = torch.randint(0, 8, (5, ), dtype=torch.int64)
     e = torch.randn(4, 3)
     f = torch.randn(4, 4, 4)
     size = [0, 1]
     dims = [0, 1]
     return (
         torch.atleast_1d(a),
         torch.atleast_2d(a),
         torch.atleast_3d(a),
         torch.bincount(c),
         torch.block_diag(a),
         torch.broadcast_tensors(a),
         torch.broadcast_to(a, (4)),
         # torch.broadcast_shapes(a),
         torch.bucketize(a, b),
         torch.cartesian_prod(a),
         torch.cdist(e, e),
         torch.clone(a),
         torch.combinations(a),
         torch.corrcoef(a),
         # torch.cov(a),
         torch.cross(e, e),
         torch.cummax(a, 0),
         torch.cummin(a, 0),
         torch.cumprod(a, 0),
         torch.cumsum(a, 0),
         torch.diag(a),
         torch.diag_embed(a),
         torch.diagflat(a),
         torch.diagonal(e),
         torch.diff(a),
         torch.einsum("iii", f),
         torch.flatten(a),
         torch.flip(e, dims),
         torch.fliplr(e),
         torch.flipud(e),
         torch.kron(a, b),
         torch.rot90(e),
         torch.gcd(c, c),
         torch.histc(a),
         torch.histogram(a),
         torch.meshgrid(a),
         torch.lcm(c, c),
         torch.logcumsumexp(a, 0),
         torch.ravel(a),
         torch.renorm(e, 1, 0, 5),
         torch.repeat_interleave(c),
         torch.roll(a, 1, 0),
         torch.searchsorted(a, b),
         torch.tensordot(e, e),
         torch.trace(e),
         torch.tril(e),
         torch.tril_indices(3, 3),
         torch.triu(e),
         torch.triu_indices(3, 3),
         torch.vander(a),
         torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
         torch.view_as_complex(torch.randn(4, 2)),
         torch.resolve_conj(a),
         torch.resolve_neg(a),
     )
Example #12
    def decode(self, ys, state=None, mems=None, cache=None, incremental=False):
        """Decode function.

        Args:
            ys (LongTensor): `[B, L]`
            state (List): dummy interface for RNNLM
            mems (List): length `n_layers`, each of which contains a FloatTensor `[B, mlen, d_model]`
            cache (List): length `L`, each of which contains a FloatTensor `[B, L-1, d_model]`
            incremental (bool): ASR decoding mode
        Returns:
            logits (FloatTensor): `[B, L, vocab]`
            out (FloatTensor): `[B, L, d_model]`
            new_cache (List): length `n_layers`, each of which contains a FloatTensor `[B, L, d_model]`

        """
        # for ASR decoding
        if cache is None:
            cache = [None] * self.n_layers  # 1-th to L-th layer

        if mems is None:
            mems = self.init_memory()
            mlen = 0
        else:
            mlen = mems[0].size(1)

        bs, ylen = ys.size()[:2]
        if incremental and cache[0] is not None:
            ylen = cache[0].size(1) + 1

        # Create the self-attention mask
        causal_mask = ys.new_ones(ylen, ylen + mlen).byte()
        causal_mask = torch.tril(causal_mask, diagonal=0 + mlen, out=causal_mask).unsqueeze(0)
        causal_mask = causal_mask.repeat([bs, 1, 1])  # `[B, L, L+mlen]`

        if self.embed_cache is not None:
            out = self.embed_cache[ys]
        else:
            out = self.dropout_emb(self.embed(ys.long()) * self.scale)

        pos_embs = self.pos_emb(ys, mlen=mlen)

        new_mems = [None] * self.n_layers
        new_cache = [None] * self.n_layers
        hidden_states = [out]
        for lth, (mem, layer) in enumerate(zip(mems, self.layers)):
            if incremental and mlen > 0 and mem.size(0) != bs:
                mem = mem.repeat([bs, 1, 1])
            out = layer(out, causal_mask, cache=cache[lth],
                        pos_embs=pos_embs, memory=mem, u_bias=self.u_bias, v_bias=self.v_bias)
            if incremental:
                new_cache[lth] = out
            elif lth < self.n_layers - 1:
                hidden_states.append(out)
                # NOTE: outputs from the last layer are not used for memory
            if not self.training and layer.yy_aws is not None:
                setattr(self, 'yy_aws_layer%d' % lth, tensor2np(layer.yy_aws))
        out = self.norm_out(out)
        if self.adaptive_softmax is None:
            logits = self.output(out)
        else:
            logits = out

        if incremental:
            # NOTE: do not update memory here during ASR decoding
            return logits, out, new_cache
        else:
            # Update memory
            new_mems = self.update_memory(mems, hidden_states)
            return logits, out, new_mems
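The `diagonal=0 + mlen` argument lets every query attend to all cached memory positions plus the non-future current positions; a tiny standalone sketch of that mask shape (sizes are illustrative):
import torch

ylen, mlen = 3, 2
causal_mask = torch.tril(torch.ones(ylen, ylen + mlen, dtype=torch.uint8), diagonal=mlen)
print(causal_mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.uint8)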
Example #13
    def _forward(self, dec_inp, mems=None):
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen

        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
                    + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]

        hids = []
        if self.attn_type == 0: # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, 
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if self.prune_masks is not None:
                    core_out = layer(core_out, pos_emb, self.r_w_bias,
                            self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i, prune_mask=self.prune_masks[i])
                else:
                    core_out = layer(core_out, pos_emb, self.r_w_bias,
                            self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1: # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len :]
                    r_bias = self.r_bias[i][-self.clamp_len :]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                if self.prune_masks is not None:
                    core_out = layer(core_out, r_emb, self.r_w_bias[i],
                            r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i, prune_mask=self.prune_masks[i])
                else:
                    core_out = layer(core_out, r_emb, self.r_w_bias[i],
                            r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2: # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                if self.prune_masks is not None:
                    core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i, prune_mask=self.prune_masks[i])
                else:
                    core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
                if self.prune_masks is not None:
                    core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                     mems=mems_i, prune_mask=self.prune_masks[i])
                else:
                    core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        return core_out, new_mems
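A standalone sketch of the `same_length` mask construction above (parameter values are illustrative): entries equal to 1 are masked, `triu(·, 1 + mlen)` hides future keys, and `tril(·, -mask_shift_len)` hides keys too far in the past, so every query sees the same context length.
import torch

qlen, mlen, mem_len = 4, 4, 4
klen = qlen + mlen
all_ones = torch.ones(qlen, klen, dtype=torch.uint8)
mask_len = klen - mem_len
mask_shift_len = qlen - mask_len if mask_len > 0 else qlen
dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                 + torch.tril(all_ones, -mask_shift_len))[:, :, None]

visible = (dec_attn_mask[:, :, 0] == 0).sum(dim=1)
assert (visible == visible[0]).all()   # every query attends to the same number of keys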
Example #14
def get_seq_mask(targets):
    batch_size, steps = targets.size()
    seq_mask = torch.ones([batch_size, steps, steps], device=targets.device)
    seq_mask = torch.tril(seq_mask).bool()
    return seq_mask
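Illustrative use of such a mask (a sketch with assumed shapes): a True entry at (t, s) means step t may look at step s, so the complement is filled with -inf before the softmax.
import torch

seq_mask = torch.tril(torch.ones(2, 4, 4)).bool()        # what get_seq_mask returns for a (2, 4) batch
scores = torch.randn(2, 4, 4)
scores = scores.masked_fill(~seq_mask, float("-inf"))    # hide future steps
probs = torch.softmax(scores, dim=-1)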
Example #15
    def forward(self,
                input_ids=None,
                mems=None,
                head_mask=None,
                inputs_embeds=None):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import TransfoXLTokenizer, TransfoXLModel
        import torch

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states, mems = outputs[:2]

        """
        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            qlen, bsz = input_ids.size()
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        if mems is None:
            mems = self.init_mems(bsz)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(
                    0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(dtype=next(self.parameters(
            )).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        if inputs_embeds is not None:
            word_emb = inputs_embeds
        else:
            word_emb = self.word_emb(input_ids)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen) +
                             torch.tril(all_ones, -mask_shift_len))[:, :,
                                                                    None]  # -1
        else:
            dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen),
                                                         dtype=torch.uint8),
                                       diagonal=1 + mlen)[:, :, None]

        hids = []
        attentions = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1,
                                   -1,
                                   -1.0,
                                   device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(core_out,
                                      pos_emb,
                                      dec_attn_mask=dec_attn_mask,
                                      mems=mems_i,
                                      head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        # We transpose back here to shape [bsz, len, hidden_dim]
        outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
        if self.output_hidden_states:
            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
            hids.append(core_out)
            hids = list(t.transpose(0, 1).contiguous() for t in hids)
            outputs.append(hids)
        if self.output_attentions:
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = list(
                t.permute(2, 3, 0, 1).contiguous() for t in attentions)
            outputs.append(attentions)

        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
Example #16
 def make_tgt_mask(self, tgt):
     N, tgt_len = tgt.shape
     tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).expand(
         N, 1, tgt_len, tgt_len)
     return tgt_mask.to(self.device)
Example #17
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        timestamp=None,
        category_ids=None,
        position_ids=None,
        elapsed_time=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                'You cannot specify both input_ids and inputs_embeds at the same time'
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError(
                'You have to specify either input_ids or inputs_embeds')

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        # Causal mask: keep only the lower triangle so no position attends to the future
        attention_mask = torch.tril(
            torch.matmul(attention_mask[:, :, None], attention_mask[:, None, :]
                         ))  # [batch_size, seq_length, seq_length]

        if category_ids is None:
            category_ids = torch.zeros(input_shape,
                                       dtype=torch.long,
                                       device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size(
            )
            encoder_hidden_shape = (encoder_batch_size,
                                    encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape,
                                                    device=device)

            # Causal mask over the encoder states: keep only the lower triangle
            encoder_attention_mask = torch.tril(
                torch.matmul(encoder_attention_mask[:, :, None],
                             encoder_attention_mask[:, None, :])
            )  # [batch_size, seq_length, seq_length]

            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        if position_ids is None:
            position_ids = (
                1 -
                (timestamp <= timestamp.roll(1, dims=1)).long()).cumsum(dim=1)

        position_bias = self.compute_bias(position_ids)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask,
                                       self.config.num_hidden_layers)

        embedding_output = self.embeddings(input_ids=input_ids,
                                           category_ids=category_ids,
                                           timestamp=timestamp,
                                           elapsed_time=elapsed_time,
                                           inputs_embeds=inputs_embeds)
        encoder_output = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            position_bias=position_bias,
            timestamp=timestamp,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
        )
        return encoder_output
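The outer-product-then-tril construction above fuses the padding mask with causality: position (i, j) stays visible only if both tokens are real and j <= i. A small sketch (values are illustrative):
import torch

attention_mask = torch.tensor([[1., 1., 1., 0.]])   # one sequence, last token is padding
pairwise = torch.matmul(attention_mask[:, :, None], attention_mask[:, None, :])
causal = torch.tril(pairwise)                       # [batch_size, seq_length, seq_length]
print(causal[0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [0., 0., 0., 0.]])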
Example #18
def make_graph_penalise(diagnoses, scores, batch_size=1000, debug=True, k=3, mode='k_closest', save_edge_values=True):
    print('==> Getting edges')
    if debug:
        diagnoses = diagnoses[:1000]
    no_pts = len(diagnoses)
    diags_per_pt = diagnoses.sum(axis=1)
    diags_per_pt = torch.tensor(diags_per_pt.values).type(torch.ShortTensor)
    del diagnoses

    if save_edge_values:
        edges_val = sparse.lil_matrix((no_pts, no_pts), dtype=np.int16)
    edges = sparse.lil_matrix((no_pts, no_pts), dtype=np.uint8)

    down = torch.split(diags_per_pt.repeat(no_pts, 1), batch_size, dim=0)
    across = torch.split(diags_per_pt.repeat(no_pts, 1).permute(1, 0), batch_size, dim=0)
    scores = scores.fill_diagonal_(0)  # remove self scores on diagonal
    score = torch.split(scores, batch_size, dim=0)
    prev_pts = 0
    for i, (d, a, s) in enumerate(zip(down, across, score)):
        print('==> Processed {} patients'.format(prev_pts))
        total_combined_diags = d + a
        s_pen = 5 * s - total_combined_diags  # the 5 is fairly arbitrary but I don't want to penalise not sharing diagnoses too much
        if mode == 'k_closest':
            k_ = k
        else:
            k_ = 1 # make sure there is at least one edge for each node in the threshold graph
        for patient in range(len(d)):
            k_highest_inds = torch.sort(s_pen[patient].flatten()).indices[-k_:]
            if save_edge_values:
                k_highest_vals = torch.sort(s_pen[patient].flatten()).values[-k_:]
                for i, val in zip(k_highest_inds, k_highest_vals):
                    if val == 0:  # these get removed if val is 0
                        val = 1
                    edges_val[patient + prev_pts, i] = val
            for i in k_highest_inds:
                edges[patient + prev_pts, i] = 1
        prev_pts += batch_size
        if mode == 'threshold':
            scores_lower = torch.tril(s_pen, diagonal=-1)
            if i == 0:  # define threshold
                desired_no_edges = k * len(s_pen)
                threshold_value = torch.sort(scores_lower.flatten()).values[-desired_no_edges]
            # for batch in batch(no_pts, n=10):
            for batch in torch.split(scores_lower, 100, dim=0):
                batch[batch < threshold_value] = 0
            edges[batch_size * i:batch_size * i + len(scores_lower)] = \
                edges[batch_size * i:batch_size * i + len(scores_lower)] + \
                sparse.lil_matrix(scores_lower)

    del scores, score, down, across, d, a, s, total_combined_diags, s_pen

    # make it symmetric again
    edges = edges + edges.transpose()
    if save_edge_values:
        edges_val = edges_val + edges_val.transpose()
        for i, (edge, edge_val) in enumerate(zip(edges, edges_val)):
            edges_val[i, edge.indices] = edge_val.data // edge.data
        edges = edges_val
    edges.setdiag(0)  # remove any left over self edges from patients without any diagnoses (these will be generally matched with others having no diagnoses)
    edges.eliminate_zeros()
    # do upper triangle again and then save
    edges = sparse.tril(edges, k=-1)
    v, u, vals = sparse.find(edges)
    return u, v, vals, k
Example #19
def subsequent_mask(size: int, device: str = "cpu") -> torch.BoolTensor:
    """Mask out subsequent positions."""
    mask = torch.tril(torch.ones(size, size, device=device,
                                 dtype=torch.bool)).unsqueeze(0)
    return mask
Example #20
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    return torch.tril(torch.ones(attn_shape, dtype=torch.bool))
Example #21
    def forward(
        self, query, key, value, mask_future_timesteps=False,
        key_padding_mask=None, use_scalar_bias=False,
    ):
        """Input shape: Time x Batch x Channel
        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        src_len, bsz, out_channels = key.size()
        tgt_len = query.size(0)
        assert list(query.size()) == [tgt_len, bsz, out_channels]
        assert key.size() == value.size()

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.downsample:
            size = bsz
        else:
            size = bsz * self.num_heads

        k = key
        v = value
        q = query
        if self.project_input:
            q = self.in_proj_q(q)
            k = self.in_proj_k(k)
            v = self.in_proj_v(v)
            src_len = k.size()[0]
        q *= self.scaling

        if not self.downsample:
            q = q.view(tgt_len, size, self.head_dim)
            k = k.view(src_len, size, self.head_dim)
            v = v.view(src_len, size, self.head_dim)

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if mask_future_timesteps:
            assert query.size() == key.size(), \
                'mask_future_timesteps only applies to self-attention'
            attn_weights *= torch.tril(
                attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
                diagonal=-1,
            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
            attn_weights += torch.triu(
                attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
                diagonal=0
            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
        tgt_size = tgt_len
        if use_scalar_bias:
            attn_weights = scalar_bias(attn_weights, 2)
            v = scalar_bias(v, 1)
            tgt_size += 1

        if key_padding_mask is not None:
            # don't attend to padding symbols
            if key_padding_mask.max() > 0:
                if self.downsample:
                    attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
                else:
                    attn_weights = attn_weights.view(size, self.num_heads, tgt_len, src_len)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    -math.inf,
                )
                attn_weights = attn_weights.view(size, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        if self.downsample:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        attn = self.out_proj(attn)

        return attn, attn_weights
Example #22
 def check(self, value):
     return (torch.tril(value) == value).min(-1).min(-1)
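The constraint check above in a standalone form (a sketch; `.all` is used here as the more explicit equivalent of the snippet's chained `.min` reductions): a matrix passes iff zeroing its upper triangle changes nothing.
import torch

value = torch.randn(3, 3).tril()                    # e.g. a Cholesky-style factor
ok = (torch.tril(value) == value).all(-1).all(-1)
assert bool(ok)

bad = torch.randn(3, 3)                             # generic matrix with a nonzero upper triangle
assert not bool((torch.tril(bad) == bad).all(-1).all(-1))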
Example #23
    def forward(
        self,
        input_ids=None,
        mems=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            qlen, bsz = input_ids.size()
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        if mems is None:
            mems = self.init_mems(bsz)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(
                    0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(dtype=next(self.parameters(
            )).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        if inputs_embeds is not None:
            word_emb = inputs_embeds
        else:
            word_emb = self.word_emb(input_ids)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen) +
                             torch.tril(all_ones, -mask_shift_len))[:, :,
                                                                    None]  # -1
        else:
            dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen),
                                                         dtype=torch.uint8),
                                       diagonal=1 + mlen)[:, :, None]

        hids = []
        attentions = [] if output_attentions else None
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1,
                                   -1,
                                   -1.0,
                                   device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(
                    core_out,
                    pos_emb,
                    dec_attn_mask=dec_attn_mask,
                    mems=mems_i,
                    head_mask=head_mask[i],
                    output_attentions=output_attentions,
                )
                core_out = layer_outputs[0]
                if output_attentions:
                    attentions.append(layer_outputs[1])
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        if output_hidden_states:
            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
            hids.append(core_out)
            hids = tuple(t.transpose(0, 1).contiguous() for t in hids)
        else:
            hids = None
        if output_attentions:
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = tuple(
                t.permute(2, 3, 0, 1).contiguous() for t in attentions)
        # We transpose back here to shape [bsz, len, hidden_dim]
        core_out = core_out.transpose(0, 1).contiguous()

        if not return_dict:
            return tuple(v for v in [core_out, new_mems, hids, attentions]
                         if v is not None)

        return TransfoXLModelOutput(
            last_hidden_state=core_out,
            mems=new_mems,
            hidden_states=hids,
            attentions=attentions,
        )
Example #24
    def generate_title(  # noqa C901
        self,
        abstract="",
        authors=[],
        venue="",
        affiliations=[],
        concepts=[],
        num_beams=1,
        no_repeat_ngram_size=3,
        num_return_sequences=1,
        min_length=10,
        max_length=30,
        device=None,
        early_stopping=False,
        debug=False,
    ):
        """generate paper titles given other information

        Args:
            abstract (str, optional): [paper abstract]. Defaults to ''.
            venue (str, optional): [paper venue]. Defaults to ''.
            authors (list, optional): [paper author]. Defaults to [].
            affiliations (list, optional): [paper affiliations]. Defaults to [].
            concepts (list, optional): [paper concepts]. Defaults to [].
            num_beams (int, optional): [beam search width, notice that this function will run one step of beam search in a batch, which should ensure that your gpu (if using) should be able to hold this number of instances]. Defaults to 1.
            no_repeat_ngram_size (int, optional): [n-grams phrases cannot repeat in title]. Defaults to 3.
            num_return_sequences (int, optional): [number of sequences to return]. Defaults to 1.
            min_length (int, optional): [the minimum length of the generated title]. Defaults to 10.
            max_length (int, optional): [the maximum length of the generated title]. Defaults to 30.
            early_stopping (bool, optional): [terminate generation once the target number of generated sequences reach <EOS>]. Defaults to False.
            device ([type], optional): [device for the inputs, default to cpu]. Defaults to None.
            debug (bool, optional): [if debug is true, the beam search progress will be shown]. Defaults to False.

        Returns:
            [list of (string, float)]: [a list of generated titles with their probabilities]
        """
        if num_return_sequences > num_beams:
            raise Exception(
                "num_return_sequences(%d) cannot be larger than num_beams(%d)" % (num_return_sequences, num_beams)
            )

        selected_ngrams = {}
        mask_token_id = self.tokenizer.mask_token_id
        eos_token_id = 1
        token_type_id = 0

        (
            input_ids,
            input_masks,
            token_type_ids,
            masked_lm_labels,
            position_ids,
            position_ids_second,
            masked_positions,
            num_spans,
        ) = self.build_inputs(
            title="[CLS] [SEP]",
            abstract=abstract,
            venue=venue,
            authors=authors,
            concepts=concepts,
            affiliations=affiliations,
            decode_span_type="TEXT",
            decode_span_length=0,
            max_seq_length=512,
            mask_propmt_text="",
        )

        context_length = len(input_ids)
        num_spans = 0
        decode_pos = 1
        decode_postion_ids_second = 1
        for i in range(1, context_length):
            if token_type_ids[i] == 0:
                position_ids_second[i] = i + 1

        input_ids.insert(decode_pos, mask_token_id)
        token_type_ids.insert(decode_pos, token_type_id)
        position_ids.insert(decode_pos, num_spans)
        position_ids_second.insert(decode_pos, decode_postion_ids_second)
        masked_lm_labels.insert(decode_pos, self.tokenizer.cls_token_id)

        q = [(input_ids, 0)]
        selected_entities = []

        def tensorize(x):
            return torch.LongTensor(x).to(device or "cpu")

        while True:
            batch_input_ids = tensorize([_input_ids for _input_ids, _ in q])
            batch_token_type_ids = tensorize([token_type_ids for _ in q])

            current_total_length = batch_input_ids.shape[1]
            current_entity_length = current_total_length - context_length

            batch_attention_mask = torch.ones((current_total_length, current_total_length))
            batch_attention_mask[
                decode_pos - current_entity_length + 1 : decode_pos + 1,
                decode_pos - current_entity_length + 1 : decode_pos + 1,
            ] = torch.tril(
                batch_attention_mask[
                    decode_pos - current_entity_length + 1 : decode_pos + 1,
                    decode_pos - current_entity_length + 1 : decode_pos + 1,
                ]
            )
            batch_attention_mask = batch_attention_mask.unsqueeze(0).repeat(len(q), 1, 1).to(device or "cpu")

            batch_position_ids = tensorize([position_ids for _ in q])
            batch_position_ids_second = tensorize([position_ids_second for _ in q])
            batch_masked_lm_labels = tensorize([masked_lm_labels for _ in q])
            sequence_output, pooled_output = self.bert.forward(
                input_ids=batch_input_ids,
                token_type_ids=batch_token_type_ids,
                attention_mask=batch_attention_mask,
                output_all_encoded_layers=False,
                checkpoint_activations=False,
                position_ids=batch_position_ids,
                position_ids_second=batch_position_ids_second,
            )
            masked_token_indexes = torch.nonzero((batch_masked_lm_labels + 1).view(-1)).view(-1)
            prediction_scores, _ = self.cls(sequence_output, pooled_output, masked_token_indexes)
            prediction_scores = torch.nn.functional.log_softmax(prediction_scores, dim=1)
            # suppress existing n-grams
            for idx, (_input_ids, _) in enumerate(q):
                if current_entity_length >= no_repeat_ngram_size:
                    prefix_key = tuple(_input_ids[decode_pos - no_repeat_ngram_size + 1 : decode_pos])
                    for token_id in selected_ngrams.get(prefix_key, set()):
                        prediction_scores[idx, token_id] = -10000
                prefix_key = tuple(_input_ids[decode_pos - current_entity_length : decode_pos])
                if prefix_key in selected_ngrams:
                    for token_id in selected_ngrams.get(prefix_key, set()):
                        prediction_scores[idx, token_id] = -10000
                if current_entity_length <= min_length:
                    prediction_scores[idx, eos_token_id] = -10000
                prediction_scores[idx, _input_ids[decode_pos]] = -10000

            decode_pos += 1
            _q = []
            log_probs, indices = torch.topk(prediction_scores, k=num_beams)
            for idx, (_input_ids, _last_logprob) in enumerate(q):
                for k in range(log_probs.shape[1]):
                    new_input_ids = _input_ids.copy()
                    new_input_ids.insert(decode_pos, indices[idx, k].item())
                    _q.append((new_input_ids, _last_logprob + log_probs[idx, k].item()))

            q = []
            for _input_ids, _last_logprob in _q:
                prefix_key = None
                if current_entity_length >= no_repeat_ngram_size:
                    prefix_key = tuple(_input_ids[decode_pos - no_repeat_ngram_size + 1 : decode_pos])
                    if prefix_key not in selected_ngrams:
                        selected_ngrams[prefix_key] = set()
                    selected_ngrams[prefix_key].add(_input_ids[decode_pos])
                if _input_ids[decode_pos] == eos_token_id:
                    selected_entities.append((_input_ids, _last_logprob))
                else:
                    q.append((_input_ids, _last_logprob))
            q.sort(key=lambda tup: tup[-1], reverse=True)
            selected_entities.sort(key=lambda tup: tup[-1], reverse=True)
            q = q[:num_beams]
            if current_entity_length >= max_length + 2:
                break
            if len(selected_entities) >= num_return_sequences:
                if early_stopping or len(q) == 0 or q[0][-1] <= selected_entities[num_return_sequences - 1][-1]:
                    break

            token_type_ids.insert(decode_pos, token_type_id)
            position_ids.insert(decode_pos, num_spans)
            position_ids_second.insert(decode_pos, decode_postion_ids_second)
            masked_lm_labels[decode_pos - 1] = -1
            masked_lm_labels.insert(decode_pos, self.tokenizer.cls_token_id)

            if debug:
                self.print_oag_instance(
                    input_ids=batch_input_ids[0].cpu().detach().numpy(),
                    token_type_ids=batch_token_type_ids[0].cpu().detach().numpy(),
                    input_masks=batch_attention_mask[0].cpu().detach().numpy(),
                    masked_lm_labels=batch_masked_lm_labels[0].cpu().detach().numpy(),
                    position_ids=batch_position_ids[0].cpu().detach().numpy(),
                    position_ids_second=batch_position_ids_second[0].cpu().detach().numpy(),
                    predictions=torch.topk(prediction_scores, k=5, dim=1).indices.cpu().detach().numpy(),
                )
                input("== Press Enter for next step ==")

        results = []
        for seq, logprob in selected_entities[:num_return_sequences]:
            token_ids = []
            for _id in seq[decode_pos - current_entity_length + 1 : decode_pos]:
                if _id != eos_token_id:
                    token_ids.append(_id)
                else:
                    break
            results.append((self._convert_token_ids_to_text(token_ids), logprob))
        return results
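The attention-mask manipulation in the loop above keeps the context fully visible while making only the decoded span causal; a reduced sketch of that pattern (sizes are illustrative):
import torch

total_len, span = 6, 3                       # 3 context tokens plus a 3-token decoded span
mask = torch.ones(total_len, total_len)
mask[-span:, -span:] = torch.tril(mask[-span:, -span:])
print(mask[-span:, -span:])
# tensor([[1., 0., 0.],
#         [1., 1., 0.],
#         [1., 1., 1.]])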
Example #25
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """
        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
        Returns:
          (`FloatTensor`, `FloatTensor`):
          * Computed vector `[batch x tgt_len x dim]`
          * Attention distributions for each query
             `[batch x tgt_len x src_len]`
        """

        # one step input
        assert source.dim() == 3
        one_step = True if source.size(1) == 1 else False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(~mask, -float('inf'))

        # We adopt coverage attn described in Paulus et al., 2018
        # REF: https://arxiv.org/abs/1705.04304
        if self._coverage:
            maxes = torch.max(align, 2, keepdim=True)[0]
            exp_score = torch.exp(align - maxes)

            if one_step:
                if coverage is None:
                    # t = 1 in Eq(3) from Paulus et al., 2018
                    unnormalized_score = exp_score
                else:
                    # t = otherwise in Eq(3) from Paulus et al., 2018
                    assert coverage.dim() == 3  # B x 1 x slen
                    unnormalized_score = exp_score.div(coverage + 1e-20)
            else:
                multiplier = torch.tril(torch.ones(target_l - 1, target_l - 1))
                multiplier = multiplier.unsqueeze(0).expand(
                    batch, *multiplier.size())
                multiplier = torch.autograd.Variable(multiplier)
                multiplier = multiplier.cuda() if align.is_cuda else multiplier

                penalty = torch.bmm(multiplier,
                                    exp_score[:, :-1, :])  # B x tlen-1 x slen
                no_penalty = torch.ones_like(penalty[:, -1, :])  # B x slen
                penalty = torch.cat([no_penalty.unsqueeze(1), penalty],
                                    dim=1)  # B x tlen x slen
                assert exp_score.size() == penalty.size()
                unnormalized_score = exp_score.div(penalty + 1e-20)

            # Eq.(4) from Paulus et al., 2018
            align_vectors = unnormalized_score.div(
                unnormalized_score.sum(2, keepdim=True))

        # Softmax to normalize attention weights
        else:
            align_vectors = self.softmax(align.view(batch * target_l,
                                                    source_l))
            align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        # Check output sizes
        batch_, target_l_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, target_l_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

        coverage_vector = None
        if self._coverage and one_step:
            coverage_vector = exp_score  # B x 1 x slen

        return attn_h, align_vectors, coverage_vector
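
The lower-triangular multiplier above is essentially a prefix-sum operator over decoder steps. Below is a minimal, self-contained sketch (toy shapes chosen for illustration, not taken from the snippet) showing that the torch.tril/bmm construction reproduces a cumulative sum of the unnormalized scores, i.e. the coverage denominator of Eq. (3) in Paulus et al., 2018.

import torch

# Toy shapes, assumed for illustration only.
B, T, S = 2, 4, 5
exp_score = torch.rand(B, T, S) + 1e-3

# A lower-triangular matrix of ones acts as a prefix-sum operator over steps.
prefix = torch.tril(torch.ones(T - 1, T - 1)).unsqueeze(0).expand(B, -1, -1)
coverage = torch.bmm(prefix, exp_score[:, :-1, :])            # sum of scores at steps < t
coverage = torch.cat([torch.ones(B, 1, S), coverage], dim=1)  # step 0 has no penalty

# Same quantity via an explicit cumulative sum.
expected = torch.cat([torch.ones(B, 1, S),
                      exp_score[:, :-1, :].cumsum(dim=1)], dim=1)
assert torch.allclose(coverage, expected)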
Example No. 26
0
    def forward(self, query, key, value, mask):
        """Forward of 'Dynamic Convolution'.

        This function takes query, key and value but uses only the query.
        This is just for compatibility with the self-attention layer (attention.py).

        Args:
            query (torch.Tensor): (batch, time1, d_model) input tensor
            key (torch.Tensor): (batch, time2, d_model) NOT USED
            value (torch.Tensor): (batch, time2, d_model) NOT USED
            mask (torch.Tensor): (batch, time1, time2) mask

        Return:
            x (torch.Tensor): (batch, time1, d_model) output

        """
        # linear -> GLU -- -> lightconv -> linear
        #               \        /
        #                 Linear
        x = query
        B, T, C = x.size()
        H = self.wshare
        k = self.kernel_size

        # first linear layer
        x = self.linear1(x)

        # GLU activation
        x = self.act(x)

        # get kernel of convolution
        weight = self.linear_weight(x)  # B x T x kH
        weight = F.dropout(weight, self.dropout_rate, training=self.training)
        weight = weight.view(B, T, H, k).transpose(1, 2).contiguous()  # B x H x T x k
        weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype)
        weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf"))
        weight_new = weight_new.to(x.device)  # B x H x T x T+k-1
        weight_new.as_strided(
            (B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)
        ).copy_(weight)
        weight_new = weight_new.narrow(-1, int((k - 1) / 2), T)  # B x H x T x T(k)
        if self.use_kernel_mask:
            kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0)
            weight_new = weight_new.masked_fill(kernel_mask == 0.0, float("-inf"))
        weight_new = F.softmax(weight_new, dim=-1)
        self.attn = weight_new
        weight_new = weight_new.view(B * H, T, T)

        # convolution
        x = x.transpose(1, 2).contiguous()  # B x C x T
        x = x.view(B * H, int(C / H), T).transpose(1, 2)
        x = torch.bmm(weight_new, x)  # BH x T x C/H
        x = x.transpose(1, 2).contiguous().view(B, C, T)

        if self.use_bias:
            x = x + self.bias.view(1, -1, 1)
        x = x.transpose(1, 2)  # B x T x C

        if mask is not None and not self.use_kernel_mask:
            mask = mask.transpose(-1, -2)
            x = x.masked_fill(mask == 0, 0.0)

        # second linear layer
        x = self.linear2(x)
        return x
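
The kernel mask above is what makes the dynamic convolution causal: filling the upper triangle with -inf before the softmax zeroes out all weights on future positions. A small hedged sketch with assumed toy sizes (B=1, H=1, T=4), separate from the snippet:

import torch
import torch.nn.functional as F

# Toy sizes, assumed for illustration only.
T = 4
scores = torch.randn(1, 1, T, T)                         # B x H x T x T(k)
kernel_mask = torch.tril(torch.ones(T, T)).unsqueeze(0)  # 1 x T x T
scores = scores.masked_fill(kernel_mask == 0.0, float("-inf"))
weights = F.softmax(scores, dim=-1)

# Each row is a valid distribution over positions <= its own index.
assert torch.allclose(weights.sum(-1), torch.ones(1, 1, T))
assert torch.all(weights.triu(diagonal=1) == 0)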
Example No. 27
0
    def forward_att(self, eouts, elens, ys, return_logits=False):
        """Compute XE loss for the sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
        Returns:
            loss (FloatTensor): `[1]`
            acc (float):
            ppl (float):

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        eos = eouts.new_zeros(1).fill_(self.eos).long()
        ys = [
            np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int64),
                      self.device_id) for y in ys
        ]
        ylens = np2tensor(
            np.fromiter([y.size(0) + 1 for y in ys],
                        dtype=np.int32))  # +1 for <eos>
        ys_in_pad = pad_list([torch.cat([eos, y], dim=0) for y in ys],
                             self.pad)
        ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys],
                              self.pad)

        # Create the self-attention mask
        bs, ymax = ys_in_pad.size()[:2]
        yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(
            bs, ymax, ymax)
        yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax,
                                              ymax)
        subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(),
                                     diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, ymax)
        yy_mask = yy_mask & subsequent_mask

        # Create the source-target mask
        xmax = eouts.size(1)
        x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(
            bs, ymax, xmax)
        y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).expand(
            bs, ymax, xmax)
        xy_mask = (x_mask * y_mask).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, xmax)

        ys_emb = self.pos_enc(self.embed(ys_in_pad))
        for l in range(self.n_layers):
            ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts,
                                                    xy_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
                setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
        logits = self.norm_out(ys_emb)
        if self.adaptive_softmax is None:
            logits = self.output(logits)
        if return_logits:
            return logits

        # Compute XE sequence loss
        if self.adaptive_softmax is None:
            if self.lsm_prob > 0 and self.training:
                # Label smoothing
                loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1), self.lsm_prob,
                                         self.pad)
            else:
                loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                       ys_out_pad.view(-1),
                                       ignore_index=self.pad,
                                       size_average=True)

            # Focal loss
            if self.focal_loss_weight > 0:
                fl = focal_loss(logits,
                                ys_out_pad,
                                ylens,
                                alpha=self.focal_loss_weight,
                                gamma=self.focal_loss_gamma)
                loss = loss * (
                    1 - self.focal_loss_weight) + fl * self.focal_loss_weight
        else:
            loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1)).loss

        # Compute token-level accuracy in teacher-forcing
        if self.adaptive_softmax is None:
            acc = compute_accuracy(logits, ys_out_pad, self.pad)
        else:
            acc = compute_accuracy(
                self.adaptive_softmax.log_prob(
                    logits.view((-1, logits.size(2)))), ys_out_pad, self.pad)
        ppl = min(np.exp(loss.item()), np.inf)

        # scale loss for CTC
        loss *= ylens.float().mean()

        return loss, acc, ppl
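
The mask construction used in forward_att (and again in the greedy decoding example below) combines a key-padding mask with a torch.tril subsequent mask. A minimal sketch with hypothetical target lengths, independent of the snippet above:

import torch

# Hypothetical target lengths for two sequences.
ylens = torch.tensor([3, 2])
bs, ymax = ylens.size(0), int(ylens.max())

# Padding mask: True where the key position is a real (non-pad) token.
pad_mask = (torch.arange(ymax)[None, :] < ylens[:, None])   # B x ymax
pad_mask = pad_mask.unsqueeze(1).expand(bs, ymax, ymax)     # B x ymax x ymax

# Subsequent mask: position t may only look at positions <= t.
subsequent = torch.tril(torch.ones(ymax, ymax)).bool()
yy_mask = pad_mask & subsequent.unsqueeze(0)                # True = attendable
print(yy_mask.int())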
Example No. 28
0
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""
    # Old Chinese data does not have '\t'. This is for adaptation.
    for i in range(len(label_list)):
        if '\t' not in label_list[i]:
            label_list[i] = label_list[i][0] + '\t' + label_list[i][1:]

    label_map = {label: i for i, label in enumerate(label_list)}
    label_map['_'] = -1
    label_count = [0]*len(label_list)

    # label mask (mask the classes which are not candidates)
    label_word = {label.split('\t')[0]: [] for label in label_list}
    for label in label_list:
        label_word[label.split('\t')[0]].append(label_map[label])
    masks = torch.ones((len(label_list), len(label_list))).byte()
    for i, label in enumerate(label_list):
        masks[i, label_word[label.split('\t')[0]]] = 0
    masks = torch.cat([masks.unsqueeze(0) for _ in range(8)])
    # print(masks.size(),masks)

    # hybrid attention
    attention_mask = torch.ones(12, max_seq_length, max_seq_length, dtype=torch.long)
    # left attention
    attention_mask[:2, :, :] = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.long))
    # right attention
    attention_mask[2:4, :, :] = torch.triu(torch.ones(max_seq_length, max_seq_length, dtype=torch.long))
    # local attention, window size = 3
    attention_mask[4:6, :, :] = torch.triu(
        torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.long), 1), -1)
    attention_mask = torch.cat([attention_mask.unsqueeze(0) for _ in range(8)])

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 100000 == 0:
            print(ex_index)
        tokens_a = example.text

        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[:(max_seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0   0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For the polyphony classification task, the polyphony vector is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        assert len(tokens) == len(input_ids)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)
        # [CLS] + [tokens] + [SEP]
        label_ids = [-1] * max_seq_length

        for i, l in example.label:
            try:
                if '\t' not in l:
                    l = l[0] + '\t' + l[1:]
                assert tokens[i + 1] == l.split('\t')[0]
            except Exception as e:
                print(e)
                print(tokens, i, l)
                continue
            else:
                label_ids[i + 1] = label_map[l]
                label_count[label_map[l]] += 1
        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(label_ids) == max_seq_length

        label_pos = example.position + 1  # First token is [cls]
        assert label_pos < max_seq_length
        # assert tokens[label_pos]==example.label[-1][1][0]

        # polyphony character
        char = example.char

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("label: %s (id = %s)" % (str(example.label), str(label_ids)))
            logger.info("label position: %s" % (str(label_pos)))
            logger.info("character: %s" % (char))

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          label_ids=label_ids,
                          label_pos=label_pos,
                          char=char))
    # classification weight, for balancing the classes
    weight = [(max(label_count) / (lc + 100))**1 for lc in label_count]
    print(weight)
    weight = torch.FloatTensor([weight] * 8)
    return features, masks, weight, attention_mask
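
The hybrid attention masks built above reduce to three torch.tril/torch.triu patterns per head group: a left-to-right (causal) mask, a right-to-left mask, and a local band of width 3. A small standalone sketch with an assumed sequence length of 6:

import torch

# Assumed toy sequence length.
L = 6
ones = torch.ones(L, L, dtype=torch.long)

left_mask = torch.tril(ones)                       # current and past positions
right_mask = torch.triu(ones)                      # current and future positions
local_mask = torch.triu(torch.tril(ones, 1), -1)   # band of width 3 around the diagonal
print(local_mask)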
Example No. 29
0
    def greedy(self,
               eouts,
               elens,
               max_len_ratio,
               exclude_eos=False,
               idx2token=None,
               refs_id=None,
               speakers=None,
               oracle=False):
        """Greedy decoding in the inference stage (used only for evaluation during training).

        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
            elens (IntTensor): `[B]`
            max_len_ratio (int): maximum sequence length of tokens
            exclude_eos (bool):
            idx2token ():
            refs_id (list):
            speakers (list):
            oracle (bool):
        Returns:
            best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`
            aw (list): A list of length `[B]`, which contains arrays of size `[L, T]`

        """
        bs, xmax = eouts.size()[:2]

        # Start from <sos> (<eos> in case of the backward decoder)
        ys_all = eouts.new_zeros(bs, 1).fill_(self.eos).long()

        # TODO(hirofumi): Create the source-target mask for batch decoding

        best_hyps_batch = []
        ylens = torch.zeros(bs).int()
        yy_aws_tmp = [None] * bs
        xy_aws_tmp = [None] * bs
        eos_flags = [False] * bs
        for t in range(int(np.floor(xmax * max_len_ratio)) + 1):
            # Create the self-attention mask
            yy_mask = make_pad_mask(ylens + 1,
                                    self.device_id).unsqueeze(1).expand(
                                        bs, t + 1, t + 1)
            yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, t + 1,
                                                  t + 1)
            subsequent_mask = torch.tril(yy_mask.new_ones(
                (t + 1, t + 1)).byte(),
                                         diagonal=0)
            subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
                bs, self.attn_n_heads, t + 1, t + 1)
            yy_mask = yy_mask & subsequent_mask

            # Create the source-target mask
            xmax = eouts.size(1)
            x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(
                bs, t + 1, xmax)
            y_mask = make_pad_mask(ylens + 1,
                                   self.device_id).unsqueeze(2).expand(
                                       bs, t + 1, xmax)
            xy_mask = (x_mask * y_mask).unsqueeze(1).expand(
                bs, self.attn_n_heads, t + 1, xmax)

            out = self.pos_enc(self.embed(ys_all))
            for l in range(self.n_layers):
                out, yy_aws, xy_aws = self.layers[l](out, yy_mask, eouts,
                                                     xy_mask)
            out = self.norm_out(out)

            # Pick up 1-best
            y = self.output(out).argmax(-1)[:, -1:]
            best_hyps_batch += [y]

            # Count lengths of hypotheses
            for b in range(bs):
                if not eos_flags[b]:
                    if y[b].item() == self.eos:
                        eos_flags[b] = True
                        yy_aws_tmp[b] = yy_aws[b:b + 1]  # TODO: fix this
                        xy_aws_tmp[b] = xy_aws[b:b + 1]
                    ylens[b] += 1
                    # NOTE: include <eos>

            # Break if <eos> has been output for all utterances in the mini-batch
            if sum(eos_flags) == bs:
                break

            ys_all = torch.cat([ys_all, y], dim=-1)

        # Concatenate in L dimension
        best_hyps_batch = torch.cat(best_hyps_batch, dim=1)
        # xy_aws_tmp = torch.stack(xy_aws_tmp, dim=0)

        # Convert to numpy
        best_hyps_batch = tensor2np(best_hyps_batch)
        # xy_aws_tmp = tensor2np(xy_aws_tmp)

        # if self.score.attn_n_heads > 1:
        #     xy_aws_tmp = xy_aws_tmp[:, :, :, 0]
        #     # TODO(hirofumi): fix for MHA

        # Truncate by the first <eos> (<sos> in case of the backward decoder)
        if self.bwd:
            # Reverse the order
            best_hyps = [
                best_hyps_batch[b, :ylens[b]][::-1] for b in range(bs)
            ]
            # aws = [xy_aws_tmp[b, :ylens[b]][::-1] for b in range(bs)]
        else:
            best_hyps = [best_hyps_batch[b, :ylens[b]] for b in range(bs)]
            # aws = [xy_aws_tmp[b, :ylens[b]] for b in range(bs)]

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                best_hyps = [
                    best_hyps[b][1:] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]
            else:
                best_hyps = [
                    best_hyps[b][:-1] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]

        # return best_hyps, aws
        return best_hyps, None
Example No. 30
0
def tril(mat, k=0):
    return torch.tril(mat, diagonal=k)
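
A short usage sketch (values chosen arbitrarily): the wrapper's `k` argument is forwarded to torch.tril's `diagonal` and shifts which diagonals are kept.

import torch

m = torch.arange(1., 10.).view(3, 3)
print(torch.tril(m))               # same as tril(m): keep the main diagonal and below
print(torch.tril(m, diagonal=-1))  # tril(m, k=-1): drop the main diagonal too
print(torch.tril(m, diagonal=1))   # tril(m, k=1): also keep the first super-diagonal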
Example No. 31
0
def get_masks_and_position_ids(data,
                               eod_token,
                               reset_position_ids,
                               reset_attention_mask,
                               loss_mask=None,
                               attention_mask=None,
                               set_loss_mask=False,
                               mem_length=None):
    # Extract batch size and sequence length.
    batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if mem_length:
        if attention_mask is None:
            attention_mask = torch.ones(
                (1, seq_length, seq_length + mem_length), device=data.device)
        attention_mask = torch.tril(
            torch.triu(attention_mask, 1 - seq_length + mem_length),
            mem_length)
    else:
        if reset_attention_mask:
            att_mask_batch = batch_size
        else:
            att_mask_batch = 1
        if attention_mask is None:
            attention_mask = torch.ones(
                (att_mask_batch, seq_length, seq_length), device=data.device)
        attention_mask = torch.tril(attention_mask)
    attention_mask = attention_mask.unsqueeze(1)

    # Loss mask.
    if loss_mask is None:
        loss_mask = torch.ones(data.size(),
                               dtype=torch.float,
                               device=data.device)

    # Position ids.
    position_ids = torch.arange(seq_length,
                                dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    if set_loss_mask:
        loss_mask[data == eod_token] = 0.0
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(batch_size):

            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    return attention_mask, loss_mask, position_ids
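
When mem_length is set, the triu/tril composition above does not produce a full lower triangle but a sliding causal window over the concatenated [memory, current] positions. A minimal sketch with assumed sizes (mem_length >= seq_length - 1, so every row of the window is full):

import torch

# Assumed toy sizes.
seq_length, mem_length = 4, 3
mask = torch.ones(1, seq_length, seq_length + mem_length)
mask = torch.tril(torch.triu(mask, 1 - seq_length + mem_length), mem_length)

print(mask[0].int())
# With these sizes every query attends to exactly seq_length key positions.
assert torch.all(mask.sum(-1) == seq_length)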
Example No. 32
0
    def forward(self, x):
        L = torch.tril(self.L, diagonal=-1) + torch.diag(torch.ones(self.dim))
        U = torch.triu(self.U, diagonal=1)
        z = x @ self.P @ L @ (U + torch.diag(self.S))
        log_det = torch.sum(torch.log(torch.abs(self.S)))
        return z, log_det
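
A quick numeric check (a hypothetical 3x3 case, not from the snippet) of why the returned log_det is simply sum(log|S|): P is a permutation matrix, L is unit lower-triangular, and U + diag(S) is upper-triangular with S on its diagonal, so |det W| = |prod(S)|.

import torch

# Hypothetical pieces of the LU-parameterized weight matrix.
dim = 3
P = torch.eye(dim)[torch.randperm(dim)]                              # permutation, |det| = 1
L = torch.tril(torch.randn(dim, dim), diagonal=-1) + torch.eye(dim)  # unit lower-triangular, det = 1
U = torch.triu(torch.randn(dim, dim), diagonal=1)                    # strictly upper-triangular
S = torch.tensor([0.5, -1.5, 2.0])

W = P @ L @ (U + torch.diag(S))
_, logabsdet = torch.slogdet(W)
assert torch.allclose(logabsdet, torch.log(torch.abs(S)).sum(), atol=1e-5)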
Example No. 33
0
    def generate(self, input_ids, attention_mask, decoder_start_token_id,
                 no_repeat_ngram_size, *args, **kwargs):
        """
        Args:
            input_ids: the token sequence of the text to encode
            attention_mask: the attention mask for input tokens
            decoder_start_token_id: begin of sentence token id
            no_repeat_ngram_size: no_repeat_ngram_size for beam search
            max_gen_length: max length for beam search
            min_gen_length: min length for beam search
            repetition_penalty: repetition_penalty for beam search
            num_beams: beam size for beam search
            num_return_sequences: num for return sequence for beam search
        """
        max_seq_length = kwargs.pop("max_length", 48)
        min_seq_length = kwargs.pop("min_gen_length", 0)
        repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        no_repeat_ngram_size = no_repeat_ngram_size
        length_penalty = kwargs.pop("length_penalty", 1.0)
        self.num_beams = kwargs.pop("num_beams", 5)
        num_return_sequences = kwargs.pop("num_return_sequences", 1)
        src_token, src_mask1 = input_ids, attention_mask
        batch_size = src_token.size(0)
        src_len = src_token.size(1)
        total_seq_length = max_seq_length + src_len + 1
        src_mask = src_mask1[:, None, :].repeat(1, total_seq_length, 1)
        tgt_mask = torch.zeros(batch_size, total_seq_length,
                               max_seq_length + 1).to(src_mask)
        tri_mask = torch.ones(batch_size, total_seq_length,
                              max_seq_length + 1).to(src_mask)
        tgt_mask[:, src_len:, :] = torch.tril(tri_mask[:, src_len:, :])
        tgt_mask[:, :, 0] = 0
        src_mask = torch.cat((src_mask, tgt_mask), dim=-1)
        src_seg = torch.tensor([self.config.source_type_id] *
                               src_len).to(src_token)
        src_seg = src_seg[None, :].repeat(batch_size, 1)
        src_pos0 = torch.ones(batch_size, max_seq_length + 1).to(input_ids)
        src_pos0[:, 0] = 0
        src_pos = torch.cat((input_ids, src_pos0.to(input_ids)), dim=-1).ne(0)
        src_pos = torch.cumsum(src_pos, dim=-1) - 1
        self.src_state = dict({
            "src_len": src_len,
            "src_token": src_token,
            "src_seg": src_seg,
            "src_mask": src_mask,
            "src_pos": src_pos,
        })
        dec_seg = [self.config.source_type_id
                   ] + [self.config.target_type_id] * max_seq_length
        self.dec_seg = (torch.tensor(
            dec_seg, dtype=torch.long,
            device=src_token.device).unsqueeze(0).repeat(
                src_token.size(0) * self.num_beams, 1))
        self.dec_mask_token = (torch.from_numpy(
            np.array([self.config.mask_token_id
                      ])).repeat([batch_size * self.num_beams
                                  ]).unsqueeze(-1).to(src_token.device))
        if decoder_start_token_id is not None:
            self.config.bos_token_id = decoder_start_token_id
        bos_token = (torch.from_numpy(np.array([self.config.bos_token_id
                                                ])).repeat([batch_size
                                                            ]).unsqueeze(-1))
        if torch.cuda.is_available():
            bos_token = bos_token.cuda()

        batch_hyp = super().generate(
            bos_token,
            max_length=max_seq_length - 1,
            min_length=min_seq_length,
            do_sample=False,
            num_beams=self.num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            bos_token_id=self.config.bos_token_id,
            pad_token_id=self.config.pad_token_id,
            eos_token_id=self.config.eos_token_id,
            num_return_sequences=num_return_sequences,
        )

        batch_hyp = batch_hyp.reshape(batch_size, num_return_sequences, -1)
        batch_hyp = batch_hyp[:, 0, :]
        return batch_hyp
Example No. 34
0
def conditional_corrcoeff(
    density: Any,
    limits: Tensor,
    condition: Tensor,
    subset: Optional[List[int]] = None,
    resolution: int = 50,
    warn_about_deprecation: bool = True,
) -> Tensor:
    r"""
    Returns the conditional correlation matrix of a distribution.

    To compute the conditional distribution, we condition all but two parameters to
    values from `condition`, and then compute the Pearson correlation
    coefficient $\rho$ between the remaining two parameters under the distribution
    `density`. We do so for any pair of parameters specified in `subset`, thus
    creating a matrix containing conditional correlations between any pair of
    parameters.

    If `condition` is a batch of conditions, this function computes the conditional
    correlation matrix for each one of them and returns the mean.

    Args:
        density: Probability density function with `.log_prob()` function.
        limits: Limits within which to evaluate the `density`.
        condition: Values to condition the `density` on. If a batch of conditions is
            passed, we compute the conditional correlation matrix for each of them and
            return the average conditional correlation matrix.
        subset: Evaluate the conditional distribution only on a subset of dimensions.
            If `None` this function uses all dimensions.
        resolution: Number of grid points on which the conditional distribution is
            evaluated. A higher value increases the accuracy of the estimated
            correlation but also increases the computational cost.
        warn_about_deprecation: With sbi v0.15.0, we deprecated the import of this
            function from `sbi.utils`. Instead, it should be imported from
            `sbi.analysis`.

    Returns: Average conditional correlation matrix of shape either `(num_dim, num_dim)`
    or `(len(subset), len(subset))` if `subset` was specified.
    """

    if warn_about_deprecation:
        warn(
            "Importing `conditional_corrcoeff` from `sbi.utils` is deprecated since "
            "sbi v0.15.0. Instead, use "
            "`from sbi.analysis import conditional_corrcoeff`.")

    condition = ensure_theta_batched(condition)

    if subset is None:
        subset = range(condition.shape[1])

    correlation_matrices = []
    for cond in condition:
        correlation_matrices.append(
            torch.stack([
                _compute_corrcoeff(
                    eval_conditional_density(
                        density,
                        cond,
                        limits,
                        dim1=dim1,
                        dim2=dim2,
                        resolution=resolution,
                    ),
                    limits[[dim1, dim2]],
                ) for dim1 in subset for dim2 in subset if dim1 < dim2
            ]))

    average_correlations = torch.mean(torch.stack(correlation_matrices), dim=0)

    # `average_correlations` is still a vector containing the upper triangular entries.
    # Below, assemble them into a matrix:
    av_correlation_matrix = torch.zeros((len(subset), len(subset)))
    triu_indices = torch.triu_indices(row=len(subset),
                                      col=len(subset),
                                      offset=1)
    av_correlation_matrix[triu_indices[0],
                          triu_indices[1]] = average_correlations

    # Make the matrix symmetric by copying upper diagonal to lower diagonal.
    av_correlation_matrix = torch.triu(av_correlation_matrix) + torch.tril(
        av_correlation_matrix.T)

    av_correlation_matrix.fill_diagonal_(1.0)
    return av_correlation_matrix
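
The final assembly step above (triu_indices, then triu plus tril of the transpose) is a generic way to turn a vector of pairwise values into a symmetric matrix with unit diagonal. A hedged, self-contained sketch with made-up numbers:

import torch

# Made-up pairwise correlations for 3 parameters: (0,1), (0,2), (1,2).
n = 3
pairwise = torch.tensor([0.1, 0.2, 0.3])

mat = torch.zeros(n, n)
iu = torch.triu_indices(row=n, col=n, offset=1)
mat[iu[0], iu[1]] = pairwise

# Mirror the upper triangle into the lower one and set the diagonal to 1.
mat = torch.triu(mat) + torch.tril(mat.T)
mat.fill_diagonal_(1.0)
print(mat)
# tensor([[1.0000, 0.1000, 0.2000],
#         [0.1000, 1.0000, 0.3000],
#         [0.2000, 0.3000, 1.0000]])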
Example No. 35
0
    def forward(self, tgt_seq, enc_output=None, category=None, signals=None, tags=None, **kwargs):
        decoding_type = kwargs.get('decoding_type', self.decoding_type)
        output_attentions = kwargs.get('output_attentions', False)

        if isinstance(enc_output, list):
            assert len(enc_output) == 1
            enc_output = enc_output[0]
        all_attentions = ()

        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
        if decoding_type == 'NARFormer':
            slf_attn_mask = slf_attn_mask_keypad
        elif decoding_type == 'SelfMask':
            slf_attn_mask = slf_attn_mask_keypad
            seq_len = tgt_seq.size(1)
            
            diag =  torch.tril(torch.ones((seq_len, seq_len), device=slf_attn_mask.device, dtype=torch.uint8), diagonal=0) & \
                    torch.triu(torch.ones((seq_len, seq_len), device=slf_attn_mask.device, dtype=torch.uint8), diagonal=0)
            slf_attn_mask = (slf_attn_mask + diag).gt(0)

            # the i-th target can not see itself from the inputs
            '''
            tokens: <bos>   a       girl    is      singing <eos>
            target: a       girl    is      singing <eos>   ..
            '''
            #print(slf_attn_mask[0], slf_attn_mask.shape)
        else:
            slf_attn_mask_subseq = get_subsequent_mask(tgt_seq, watch=self.watch)
            slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        non_pad_mask = get_non_pad_mask(tgt_seq)
        src_seq = torch.ones(enc_output.size(0), enc_output.size(1)).to(enc_output.device)
        attend_to_enc_output_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)

        additional_feats = None
        if decoding_type == 'NARFormer':
            if self.enhance_input == 0:
                pass
            elif self.enhance_input == 1:
                additional_feats = resampling(enc_output, tgt_seq)
            elif self.enhance_input == 2:
                additional_feats = enc_output.mean(1).unsqueeze(1).repeat(1, tgt_seq.size(1), 1)
            else:
                raise ValueError('enhance_input should be either 0, 1 or 2')
            
        if signals is not None:
            additional_feats = signals if additional_feats is None else (additional_feats + signals)

        if self.pos_attention:
            hidden_states, position_embeddings = self.embedding(tgt_seq, category=category)
        else:
            hidden_states = self.embedding(tgt_seq, additional_feats=additional_feats, category=category, tags=tags)
            position_embeddings = None

        res = []
        for i, layer_module in enumerate(self.layer):
            if not i:
                input_ = hidden_states
            else:
                input_ = layer_outputs[0]  # + hidden_states
            
            layer_outputs = layer_module(
                input_, 
                non_pad_mask=non_pad_mask, 
                attention_mask=slf_attn_mask,
                enc_output=enc_output, 
                attend_to_enc_output_mask=attend_to_enc_output_mask, 
                position_embeddings=position_embeddings, 
                word_embeddings=self.get_word_embeddings(),
                **kwargs
            )

            res.append(layer_outputs[0])
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
            embs = layer_outputs[1]

        res = [res[-1]]
        outputs = (res,embs,)

        if output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
Example No. 36
0
    def forward(self, input_ids, attention_mask, segment_ids, start_positions,
                end_positions, answer_token_ids):

        batch_size = input_ids.shape[0]
        input_ids = input_ids.view(
            batch_size * (self.current_interaction_num + 1), -1)
        attention_mask = attention_mask.view(
            batch_size * (self.current_interaction_num + 1), -1)
        question_end_index = self._get_question_end_index(input_ids)
        # Each batch is one document, and each row of the batch is a chunk of the document.
        # Make sure all rows have the same question length.
        # assert (question_end_index[0].float() == question_end_index.float().mean()).item()
        # local attention everywhere, global attention on question
        tri = torch.tril(torch.ones([input_ids.shape[1], input_ids.shape[1]],
                                    dtype=torch.long,
                                    device=input_ids.device),
                         diagonal=-1)
        attention_mask = tri[question_end_index] + 1

        # The sliding_chunks implementation of self-attention requires that seqlen is a multiple of the window size
        input_ids, attention_mask = pad_to_window_size(
            input_ids, attention_mask, self.args.attention_window,
            self.tokenizer.pad_token_id)
        sequence_output = self.model.forward(input_ids,
                                             attention_mask=attention_mask)[0]
        sequence_output = sequence_output.view(
            batch_size, self.current_interaction_num + 1,
            sequence_output.shape[1], -1)
        p = (0, 0, 0, 0, 0,
             self.max_num_of_interactions - self.current_interaction_num)
        sequence_output = torch.nn.functional.pad(sequence_output,
                                                  p).permute(0, 2, 3, 1)
        weighted_sum = self.learned_weighted_sum(sequence_output)
        weighted_sum.squeeze_(-1)
        logits = self.qa_outputs(weighted_sum)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (
            start_logits,
            end_logits,
        )
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the positions may have an extra dimension; squeeze it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # NOTE: this model predicts start and end index in the *original* question + context encoding.
            if not self.args.regular_softmax_loss:
                # loss function suggested in section 2.2 here https://arxiv.org/pdf/1710.10723.pdf
                # NOTE: this returns sum of losses, not mean, so loss won't be normalized across different batch sizes
                # but batch size is always 1, so this is not a problem
                start_loss = self.or_softmax_cross_entropy_loss_one_doc(
                    start_logits, start_positions, ignore_index=-1)
                end_loss = self.or_softmax_cross_entropy_loss_one_doc(
                    end_logits, end_positions, ignore_index=-1)
            else:
                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1)
                start_positions = start_positions[:, 0:1]
                end_positions = end_positions[:, 0:1]
                start_loss = loss_fct(start_logits, start_positions[:, 0])
                end_loss = loss_fct(end_logits, end_positions[:, 0])

            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss, ) + outputs
        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
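
The `tri[question_end_index] + 1` trick above works because row k of a strictly lower-triangular ones matrix is an indicator of the first k positions; adding 1 then yields 2 (Longformer's global attention) on the question tokens and 1 (local attention) everywhere else. A small sketch with hypothetical lengths:

import torch

# Hypothetical sequence length and per-sequence question end indices.
seq_len = 6
question_end_index = torch.tensor([2, 4])

tri = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long), diagonal=-1)
attention_mask = tri[question_end_index] + 1
print(attention_mask)
# tensor([[2, 2, 1, 1, 1, 1],
#         [2, 2, 2, 2, 1, 1]])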
Example No. 37
0
def test_forward(args, learnable):
    args = make_args(**args)

    batch_size = 4
    klen = 40
    mlen = 20
    qlen = 5
    device = "cpu"

    key = torch.FloatTensor(batch_size, klen, args['kdim'], device=device)
    memory = torch.FloatTensor(batch_size, mlen, args['kdim'], device=device)
    query = torch.FloatTensor(batch_size, qlen, args['qdim'], device=device)

    # Create the self-attention mask
    causal_mask = torch.ones(qlen, klen + mlen, device=device).byte()
    causal_mask = torch.tril(causal_mask, diagonal=0 + mlen,
                             out=causal_mask).unsqueeze(0)
    causal_mask = causal_mask.repeat([batch_size, 1,
                                      1])  # `[B, qlen, klen+mlen]`

    module_embedding = importlib.import_module(
        'neural_sp.models.modules.positional_embedding')
    pos_emb = module_embedding.XLPositionalEmbedding(args['kdim'],
                                                     args['dropout'])

    if learnable:
        u = torch.nn.Parameter(
            torch.Tensor(args['n_heads'], args['adim'] // args['n_heads']))
        u = u.to(device)
        v = torch.nn.Parameter(
            torch.Tensor(args['n_heads'], args['adim'] // args['n_heads']))
        v = v.to(device)
    else:
        u, v = None, None

    module_mha = importlib.import_module(
        'neural_sp.models.modules.relative_multihead_attention')
    attention = module_mha.RelativeMultiheadAttentionMechanism(**args)
    attention = attention.to(device)

    attention.train()
    aws = None
    for i in range(qlen):
        pos_idxs = torch.arange(klen + mlen - 1,
                                -1,
                                -1.0,
                                dtype=torch.float,
                                device=device)
        pos_embs = pos_emb(pos_idxs)

        out = attention(key,
                        query[:, i:i + 1],
                        memory,
                        mask=causal_mask[:, i:i + 1],
                        pos_embs=pos_embs,
                        u=u,
                        v=v)
        assert len(out) == 2
        cv, aws = out
        assert cv.size() == (batch_size, 1, memory.size(2))
        assert aws.size() == (batch_size, args['n_heads'], 1, klen + mlen)
Example No. 38
0
    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N , 1, trg_len, trg_len)

        return trg_mask.to(self.device)
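
A short usage sketch (toy sizes, device handling omitted): expanding the 2-D torch.tril mask to `[N, 1, trg_len, trg_len]` lets it broadcast over the attention heads, so position t only attends to positions <= t.

import torch

N, trg_len = 2, 4
trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)
print(trg_mask.shape)   # torch.Size([2, 1, 4, 4])
print(trg_mask[0, 0])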
Example No. 39
0
    def _assemble_W(self):
        """Assemble W from its pieces (P, L, U, S)."""
        L = torch.tril(self.L, diagonal=-1) + torch.diag(torch.ones(self.dim))
        U = torch.triu(self.U, diagonal=1)
        W = self.P @ L @ (U + torch.diag(self.S))
        return W
Example No. 40
0
    def _sampling(self, B):
        """ generate adj in row-wise auto-regressive fashion """
        with torch.no_grad():

            K = self.block_size
            S = self.sample_stride
            H = self.hidden_dim
            N = self.max_num_nodes
            mod_val = (N - K) % S
            if mod_val > 0:
                N_pad = N - K - mod_val + int(np.ceil((K + mod_val) / S)) * S
            else:
                N_pad = N

            A = torch.zeros(B, N_pad, N_pad).to(self.device)
            dim_input = self.embedding_dim if self.dimension_reduce else self.max_num_nodes

            ### cache node state for speed up
            node_state = torch.zeros(B, N_pad, dim_input).to(self.device)

            for ii in range(0, N_pad, S):
                # for ii in range(0, 3530, S):
                jj = ii + K
                if jj > N_pad:
                    break

                # reset to discard overlap generation
                A[:, ii:, :] = .0
                A = torch.tril(A, diagonal=-1)

                if ii >= K:
                    if self.dimension_reduce:
                        node_state[:, ii - K:ii, :] = self.decoder_input(
                            A[:, ii - K:ii, :N])
                    else:
                        node_state[:, ii - K:ii, :] = A[:, ii - S:ii, :N]
                else:
                    if self.dimension_reduce:
                        node_state[:, :ii, :] = self.decoder_input(
                            A[:, :ii, :N])
                    else:
                        node_state[:, :ii, :] = A[:, ii - S:ii, :N]

                node_state_in = F.pad(node_state[:, :ii, :], (0, 0, 0, K),
                                      'constant',
                                      value=.0)

                ### GNN propagation
                adj = F.pad(A[:, :ii, :ii], (0, K, 0, K),
                            'constant',
                            value=1.0)  # B X jj X jj
                adj = torch.tril(adj, diagonal=-1)
                adj = adj + adj.transpose(1, 2)
                edges = [
                    adj[bb].to_sparse().coalesce().indices() +
                    bb * adj.shape[1] for bb in range(B)
                ]
                edges = torch.cat(edges, dim=1).t()

                att_idx = torch.cat(
                    [torch.zeros(ii).long(),
                     torch.arange(1, K + 1)]).to(self.device)
                att_idx = att_idx.view(1,
                                       -1).expand(B,
                                                  -1).contiguous().view(-1, 1)

                if self.has_rand_feat:
                    # create random feature
                    att_edge_feat = torch.zeros(
                        edges.shape[0], 2 * self.att_edge_dim).to(self.device)
                    idx_new_node = (att_idx[[edges[:, 0]]] > 0).long() + (
                        att_idx[[edges[:, 1]]] > 0).long()
                    idx_new_node = idx_new_node.byte().squeeze()
                    att_edge_feat[idx_new_node, :] = torch.randn(
                        idx_new_node.long().sum(),
                        att_edge_feat.shape[1]).to(self.device)
                else:
                    # create one-hot feature
                    att_edge_feat = torch.zeros(
                        edges.shape[0], 2 * self.att_edge_dim).to(self.device)
                    att_edge_feat = att_edge_feat.scatter(
                        1, att_idx[[edges[:, 0]]], 1)
                    att_edge_feat = att_edge_feat.scatter(
                        1, att_idx[[edges[:, 1]]] + self.att_edge_dim, 1)

                node_state_out = self.decoder(node_state_in.view(-1, H),
                                              edges,
                                              edge_feat=att_edge_feat)
                node_state_out = node_state_out.view(B, jj, -1)

                idx_row, idx_col = np.meshgrid(np.arange(ii, jj),
                                               np.arange(jj))
                idx_row = torch.from_numpy(idx_row.reshape(-1)).long().to(
                    self.device)
                idx_col = torch.from_numpy(idx_col.reshape(-1)).long().to(
                    self.device)

                diff = node_state_out[:,
                                      idx_row, :] - node_state_out[:,
                                                                   idx_col, :]  # B X (ii+K)K X H
                diff = diff.view(-1, node_state.shape[2])
                log_theta = self.output_theta(diff)
                log_alpha = self.output_alpha(diff)

                log_theta = log_theta.view(
                    B, -1, K, self.num_mix_component)  # B X K X (ii+K) X L
                log_theta = log_theta.transpose(1, 2)  # B X (ii+K) X K X L

                log_alpha = log_alpha.view(
                    B, -1, self.num_mix_component)  # B X K X (ii+K)
                prob_alpha = F.softmax(log_alpha.mean(dim=1), -1)
                alpha = torch.multinomial(prob_alpha, 1).squeeze(dim=1).long()

                prob = []
                for bb in range(B):
                    prob += [torch.sigmoid(log_theta[bb, :, :, alpha[bb]])]

                prob = torch.stack(prob, dim=0)
                A[:, ii:jj, :jj] = torch.bernoulli(prob[:, :jj - ii, :])

            ### make it symmetric
            if self.is_sym:
                A = torch.tril(A, diagonal=-1)
                A = A + A.transpose(1, 2)

            return A
Example No. 41
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--on_memory",
        action='store_true',
        help="Whether to load train samples into memory or use disk")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--hybrid_attention",
                        action='store_true',
                        help="Whether to use hybrid attention")
    parser.add_argument("--continue_training",
                        action='store_true',
                        help="Continue training from a checkpoint")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError(
            "Training is currently the only implemented execution option. Please set `do_train`."
        )

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and not args.continue_training:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    #train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.train_file)
        train_dataset = BERTDataset(args.train_file,
                                    tokenizer,
                                    seq_len=args.max_seq_length,
                                    corpus_lines=None,
                                    on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    if args.hybrid_attention:
        max_seq_length = args.max_seq_length
        attention_mask = torch.ones(12,
                                    max_seq_length,
                                    max_seq_length,
                                    dtype=torch.long)
        # left attention
        attention_mask[:2, :, :] = torch.tril(
            torch.ones(max_seq_length, max_seq_length, dtype=torch.long))
        # right attention
        attention_mask[2:4, :, :] = torch.triu(
            torch.ones(max_seq_length, max_seq_length, dtype=torch.long))
        # local attention, window size = 3
        attention_mask[4:6, :, :] = torch.triu(
            torch.tril(
                torch.ones(max_seq_length, max_seq_length, dtype=torch.long),
                1), -1)
        attention_mask = torch.cat(
            [attention_mask.unsqueeze(0) for _ in range(8)])
        attention_mask = attention_mask.to(device)
    else:
        attention_mask = None

    global_step = 0
    epoch_start = 0
    if args.do_train:
        if args.continue_training:
            # if checkpoint file exists, find the last checkpoint
            if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
                all_cp = os.listdir(args.output_dir)
                steps = [
                    int(re.search('_\d+', cp).group()[1:]) for cp in all_cp
                    if re.search('_\d+', cp)
                ]
                if len(steps) == 0:
                    raise ValueError(
                        "No existing checkpoint. Please do not use --continue_training."
                    )
                max_step = max(steps)
                # load checkpoint
                checkpoint = torch.load(
                    os.path.join(args.output_dir,
                                 'checkpoints_' + str(max_step) + '.pt'))
                logger.info("***** Loading checkpoint *****")
                logger.info("  Num steps = %d", checkpoint['global_step'])
                logger.info("  Num epoch = %d", checkpoint['epoch'])
                logger.info("  Loss = %d, %d", checkpoint['loss'],
                            checkpoint['loss_now'])
                model.module.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                global_step = checkpoint['global_step']
                epoch_start = checkpoint['epoch']
                del checkpoint
            else:
                raise ValueError(
                    "No existing checkpoint. Please do not use --continue_training."
                )

        writer = SummaryWriter(log_dir=os.environ['HOME'])
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            #TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        tr_loss_1000 = 0
        for ep in trange(epoch_start, int(args.num_train_epochs),
                         desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch
                loss = model(input_ids,
                             segment_ids,
                             input_mask,
                             lm_label_ids,
                             hybrid_mask=attention_mask)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                tr_loss_1000 += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                # log the training loss for every 1000 steps
                if global_step % 1000 == 999:
                    writer.add_scalar('data/loss', tr_loss_1000 / 1000,
                                      global_step)
                    logger.info("training steps: %s", global_step)
                    logger.info("training loss per 1000: %s",
                                tr_loss_1000 / 1000)
                    tr_loss_1000 = 0
                # save the checkpoint for every 10000 steps
                if global_step % 10000 == 0:
                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model it-self
                    output_file = os.path.join(
                        args.output_dir,
                        "checkpoints_" + str(global_step) + ".pt")
                    checkpoint = {
                        'model': model_to_save.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': ep,
                        'global_step': global_step,
                        'loss': tr_loss / nb_tr_steps,
                        'loss_now': tr_loss_1000
                    }
                    torch.save(checkpoint, output_file)
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(args.output_dir,
                                             "pytorch_model.bin_" + str(ep))
            torch.save(model_to_save.state_dict(), output_model_file)
            logger.info("training loss: %s", tr_loss / nb_tr_steps)

        # Save a trained model
        logger.info("** ** * Saving fine - tuned model ** ** * ")
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        torch.save(model_to_save.state_dict(), output_model_file)
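
The fp16 branch above rescales the learning rate by hand through warmup_linear. In the BERT pretraining utilities this is a linear warmup followed by a linear decay; the sketch below is only a rough re-implementation of that shape, not the exact library function, and the peak learning rate, warmup proportion and step counts are made up for illustration.

# Minimal sketch of a linear warmup + linear decay schedule, assumed to approximate
# the `warmup_linear` helper used in the fp16 branch above (not the library's exact code).
def warmup_linear(progress: float, warmup: float = 0.002) -> float:
    """Ramp the LR multiplier from 0 to 1 during warmup, then decay it linearly to 0."""
    if progress < warmup:
        return progress / warmup
    return max(0.0, (1.0 - progress) / (1.0 - warmup))

# Illustrative numbers only: peak LR 3e-5, 10% warmup, 10,000 optimizer steps.
peak_lr, warmup_proportion, total_steps = 3e-5, 0.1, 10_000
for step in (0, 500, 1_000, 5_000, 10_000):
    lr = peak_lr * warmup_linear(step / total_steps, warmup_proportion)
    print(f"step {step:>6}: lr = {lr:.2e}")
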
Ejemplo n.º 42
0
    def forward(self, z1ss, pos_emb, u1ss, mems=None):
        # Note: In this context, qlen means the length of the (small) subsequence; and mlen describes
        #       the length of the padding. Their sum is klen. 
        bsz, d_model, qlen = z1ss.size()
        r_w_bias, r_r_bias = self.r_w_bias, self.r_r_bias
        n_head, d_head = self.n_head, self.d_head
        rlen = pos_emb.size(2)
        
        if mems is None: 
            mems = torch.tensor([]).view(0,0,0)
        mlen = mems.size(2)
        cat = torch.cat([mems, z1ss], dim=-1)

        if self.pre_lnorm:
            cat = F.layer_norm(cat.transpose(1,2), (d_model,)).transpose(1,2)
        w_heads = self.qkv_net(cat)      # (N, 3C, L)
        r_head_k = self.r_net(pos_emb)

        # Input injection
        w_heads += u1ss
        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=1)
        w_head_q = w_head_q[:,:,-qlen:]

        klen = w_head_k.size(2)

        w_head_q = w_head_q.view(bsz, n_head, d_head, qlen)           # bsz x n_head x d_head x qlen
        w_head_k = w_head_k.view(bsz, n_head, d_head, klen)           # bsz x n_head x d_head x klen
        w_head_v = w_head_v.view(bsz, n_head, d_head, klen)           # bsz x n_head x d_head x klen

        r_head_k = r_head_k.view(n_head, d_head, rlen)                # n_head x d_head x rlen

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[:,:,None]                   # bsz x n_head x d_head x qlen
        AC = torch.einsum('bndi,bndj->bnij', rw_head_q, w_head_k)
        rr_head_q = w_head_q + r_r_bias[:,:,None]
        BD = torch.einsum('bndi,ndj->bnij', rr_head_q, r_head_k)
        BD = self._rel_shift(BD)    # for the sake of relative positional embedding

        attn_score = AC + BD        # bsz x n_head x qlen x klen
        attn_score.mul_(self.scale)
            
        #### compute attention probability
        # We apply a local mask, with local horizon size of mlen
        local_size = self.local_size or 1000
        # Build the mask on the same device as the activations and as a boolean tensor;
        # masked_fill with byte masks is deprecated in recent PyTorch versions.
        attn_mask = torch.triu(torch.ones(qlen, klen, device=z1ss.device), diagonal=1+mlen).bool()[None,:,:]
        attn_mask = attn_mask | torch.tril(torch.ones(qlen, klen, device=z1ss.device), diagonal=mlen-local_size).bool()[None,:,:]
        if attn_mask.any():
            attn_score = attn_score.float().masked_fill(
                    attn_mask[None,:,:,:], -float('inf')).type_as(attn_score)
                
        attn_prob = F.softmax(attn_score, dim=-1)          # bsz x n_head x qlen x klen
            
        #### compute attention vector
        attn_vec = torch.einsum('bnij,bndj->bndi', attn_prob, w_head_v)
        
        # [bsz x d x qlen]
        attn_vec = attn_vec.contiguous().view(bsz, n_head*d_head, attn_vec.size(-1))

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)
        
        ##### residual connection + layer normalization (if applicable)
        if self.pre_lnorm:
            out = attn_out + z1ss
        else:
            out = F.layer_norm((attn_out + z1ss).transpose(1,2), (d_model,)).transpose(1,2)
        return out
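
The forward pass above combines a causal mask (torch.triu shifted by mlen) with a locality mask (torch.tril shifted by mlen - local_size), so each query attends to at most local_size recent keys. The standalone sketch below prints that band for small values of qlen, mlen and local_size chosen purely for illustration.

# Standalone sketch of the banded mask built in the forward pass above.
# 1 marks positions that are filled with -inf before the softmax; all sizes are made up.
import torch

qlen, mlen, local_size = 4, 2, 3
klen = qlen + mlen
causal = torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen)            # no attending to future keys
too_far = torch.tril(torch.ones(qlen, klen), diagonal=mlen - local_size)  # no attending beyond the local horizon
attn_mask = (causal + too_far).bool()
print(attn_mask.int())
# Each query row i can only see keys j with i + mlen - local_size < j <= i + mlen,
# i.e. a window of at most local_size positions ending at its own (shifted) index.
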
def subsequent_mask(size: int, device: str = 'cpu') -> torch.Tensor:
    """Mask out subsequent positions."""
    mask = torch.tril(torch.ones(size, size, device=device, dtype=torch.int32)).unsqueeze(0)
    return mask
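
A short usage sketch for subsequent_mask: the returned (1, size, size) lower-triangular tensor lets position i attend only to positions up to i. The score tensor below is invented purely for illustration.

import torch

mask = subsequent_mask(4)
print(mask[0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)

scores = torch.randn(2, 4, 4)                        # (batch, query, key) attention logits, made up
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)                 # future positions receive zero weight
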