Example No. 1
    def forward(self, x, hx=None):
        x, batch_sizes = x
        batch_size = batch_sizes[0]

        if hx is None:
            init = x.new_zeros(batch_size, self.hidden_size)
            hx = (init, init)

        for layer in range(self.num_layers):
            if self.training:
                # build one dropout mask from the first timestep's rows and reuse
                # it across all timesteps (shared/variational dropout on the inputs)
                mask = SharedDropout.get_mask(x[:batch_size], self.dropout)
                mask = torch.cat(
                    [mask[:batch_size] for batch_size in batch_sizes])
                x *= mask
            # split the flat packed data into per-timestep chunks
            x = torch.split(x, batch_sizes.tolist())
            f_output = self.layer_forward(x=x,
                                          hx=hx,
                                          cell=self.f_cells[layer],
                                          batch_sizes=batch_sizes,
                                          reverse=False)
            b_output = self.layer_forward(x=x,
                                          hx=hx,
                                          cell=self.b_cells[layer],
                                          batch_sizes=batch_sizes,
                                          reverse=True)
            x = torch.cat([f_output, b_output], -1)
        x = PackedSequence(x, batch_sizes)

        return x
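For context, a minimal usage sketch (the module name `lstm` and all sizes below are assumptions): this forward expects its input as a `(data, batch_sizes)` pair, which can be taken directly from a `PackedSequence`.

import torch
from torch.nn.utils.rnn import pack_padded_sequence

embeddings = torch.randn(4, 20, 100)              # (batch_size, seq_len, input_size)
lengths = torch.tensor([20, 18, 15, 9])           # sorted in decreasing order
packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
# out = lstm((packed.data, packed.batch_sizes))   # hypothetical call; returns a
#                                                 # PackedSequence with 2 * hidden_size features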
Example No. 2
    def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
        h, c = hx
        init_h, init_c = h, c
        output, seq_len = [], len(x)
        steps = reversed(range(seq_len)) if reverse else range(seq_len)
        if self.training:
            hid_mask = SharedDropout.get_mask(h, self.dropout)

        for t in steps:
            last_batch_size, batch_size = len(h), batch_sizes[t]
            # grow or shrink the hidden/cell states so they match the number of
            # sequences still active at this timestep of the packed batch
            if last_batch_size < batch_size:
                h = torch.cat((h, init_h[last_batch_size:batch_size]))
                c = torch.cat((c, init_c[last_batch_size:batch_size]))
            else:
                h = h[:batch_size]
                c = c[:batch_size]
            h, c = cell(input=x[t], hx=(h, c))
            output.append(h)
            if self.training:
                # apply the same hidden-state dropout mask at every timestep
                h = h * hid_mask[:batch_size]
        if reverse:
            output.reverse()
        output = torch.cat(output)

        return output
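Both snippets above call SharedDropout.get_mask, which is not shown here. Below is a minimal sketch of such a shared ("variational") dropout helper; the actual SharedDropout used in these examples may differ.

import torch

class SharedDropout:
    """Minimal sketch (assumption); the real SharedDropout may differ."""

    @staticmethod
    def get_mask(x, p):
        # One Bernoulli keep-mask, scaled by 1/(1-p) (inverted dropout), intended to be
        # reused across timesteps so recurrent inputs are dropped consistently.
        return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p)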
Example No. 3
    def forward(self, x, y):
        if self.bias_x:
            x = torch.cat([x, x.new_ones(x.shape[:-1]).unsqueeze(-1)], -1)
        if self.bias_y:
            y = torch.cat([y, y.new_ones(y.shape[:-1]).unsqueeze(-1)], -1)
        # [batch_size, 1, seq_len, d]
        x = x.unsqueeze(1)
        # [batch_size, 1, seq_len, d]
        y = y.unsqueeze(1)
        # [batch_size, n_out, seq_len, seq_len]
        s = x @ self.weight @ y.transpose(-1, -2)
        # remove dim 1 if n_out == 1
        s = s.squeeze(1)

        return s
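The scoring line relies on batched-matmul broadcasting between the (batch_size, 1, seq_len, d) inputs and an (n_out, d, d) weight. A quick shape check (all sizes below are made up):

import torch

batch_size, seq_len, d, n_out = 2, 10, 101, 3   # e.g. 100 features + 1 bias column
x = torch.randn(batch_size, 1, seq_len, d)      # after unsqueeze(1)
y = torch.randn(batch_size, 1, seq_len, d)
weight = torch.randn(n_out, d, d)
s = x @ weight @ y.transpose(-1, -2)
print(s.shape)                                  # torch.Size([2, 3, 10, 10])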
Example No. 4
    def forward(self,
                hidden_states,
                start_states=None,
                start_positions=None,
                cls_index=None):
        """
        Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **cls_index**: torch.LongTensor of shape ``(batch_size,)``
                position of the CLS token. If None, take the last token.

            note(Original repo):
                no dependency on end_feature so that we can obtain one single `cls_logits`
                for each sample
        """
        hsz = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should not be None"
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(
                -1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(
                -2)  # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None,
                                  None].expand(-1, -1,
                                               hsz)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(
                -2)  # shape (bsz, hsz)
        else:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)

        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        x = self.dense_1(x).squeeze(-1)

        return x
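A hypothetical call to this answer-class pooler (the instance name and all sizes below are assumptions), showing the expected shapes:

import torch

batch_size, seq_len, hidden_size = 2, 16, 768
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
start_positions = torch.tensor([3, 7])   # index of the first answer token per sample
cls_index = torch.tensor([0, 0])         # position of the CLS token
# cls_logits = pooler(hidden_states, start_positions=start_positions, cls_index=cls_index)
# cls_logits.shape == (batch_size,) after the final squeeze(-1)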
Example No. 5
    def forward(self,
                hidden_states,
                start_states=None,
                start_positions=None,
                p_mask=None):
        """ Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span:
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should not be None"
        if start_positions is not None:
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions[:, None, None].expand(
                -1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(
                -2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen,
                                               -1)  # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x).squeeze(-1)

        if p_mask is not None:
            if next(self.parameters()).dtype == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x
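The p_mask branch above pushes masked positions toward minus infinity so a later softmax sends them to ~0. A small standalone sketch of that masking step (tensor sizes are made up):

import torch

batch_size, seq_len = 2, 8
logits = torch.randn(batch_size, seq_len)
p_mask = torch.zeros(batch_size, seq_len)
p_mask[:, 0] = 1.0                         # e.g. mask out the first (special) token
masked = logits * (1 - p_mask) - 1e30 * p_mask
# masked positions now carry a huge negative score, so softmax assigns them ~0 probability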
        "create a mask to hide paddding and future work"
        mask = (output_seq != padding).unsqueeze(-2)
        mask = mask & Variable(
            subsequent_mask(output_seq.size(-1)).type_as(mask.data))
        return mask


# test

if __name__ == "__main__":

    plt.figure(figsize=(5, 5))
    plt.imshow(subsequent_mask(20)[0])
    #plt.show()
    data_out = torch.cat((torch.empty(1, 4, dtype=torch.long).random_(
        2, 5), torch.ones(1, 4, dtype=torch.long)),
                         dim=1)
    data_out = torch.cat((data_out, data_out), dim=0)
    data_in = torch.cat((torch.empty(1, 4, dtype=torch.long).random_(
        2, 4), torch.zeros(1, 3, dtype=torch.long)),
                        dim=1)
    data_in[:, 0] = 2
    data_out[:, 0] = 2
    #data_in = torch.cat((data_in, data_in), dim=0)
    #data_out = data_out.unsqueeze(0)
    #data_in = data_in.unsqueeze(0)
    print("DATA IN {} {} ".format(data_in, data_in.size()))
    print("DATA OUT {} {} ".format(data_out, data_out.size()))
    batch = MaskBatch(data_in, data_out, pad=1, verbose=5)
    print("INPUT MASK {} , output mask {} ".format(batch.input_seq_mask,
                                                   batch.output_mask))
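The test above depends on a subsequent_mask helper that is not included in the snippet. A minimal sketch of the usual implementation (an assumption; the real helper may differ):

import numpy as np
import torch

def subsequent_mask(size):
    # 1 where position j may be attended to from position i (j <= i), 0 otherwise.
    attn_shape = (1, size, size)
    upper = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(upper) == 0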
def get_hidden_representation(data,
                              model,
                              tokenizer,
                              special_start="[CLS]",
                              special_end="[SEP]",
                              pad="[PAD]",
                              max_len=100,
                              pad_below_max_len=False,
                              output_dic=True):
    """
    get hidden representation (ie contetualized vector at the word level : add it as list or padded tensor : output[attention|layer|layer_head]["layer_x"] list or tensor)
    :param data: list of raw text
    :param pad: will add padding below max_len
    :return: output a dictionary (if output_dic) or a tensor (if not output_dic) : of contextualized representation at the word level per layer/layer_head
    """
    model.eval()
    special_start = tokenizer.bos_token
    special_end = tokenizer.eos_token
    if special_start is None or special_end is None:
        special_start = "[CLS]"
        special_end = "[SEP]"

    layer_head_att_tensor_dic = OrderedDict()
    layer_hidden_state_tensor_dic = OrderedDict()
    layer_head_hidden_state_tensor_dic = OrderedDict()
    layer_head_att_batch_dic = OrderedDict()
    layer_head_hidden_state_dic = OrderedDict()
    layer_hidden_state_dic = OrderedDict()
    print(
        f"Getting hidden representations: adding special tokens start={special_start} end={special_end}"
    )
    for seq in data:
        seq = special_start + " " + seq + " " + special_end
        tokenized = tokenizer.encode(seq)
        if len(tokenized) >= max_len:
            tokenized = tokenized[:max_len - 1]
            tokenized += tokenizer.encode(special_end)
        mask = [1 for _ in range(len(tokenized))]
        real_len = len(tokenized)
        if pad_below_max_len:
            if len(tokenized) < max_len:
                for _ in range(max_len - len(tokenized)):
                    tokenized += tokenizer.encode(pad)
                    mask.append(0)
            assert len(tokenized) == max_len
        assert len(tokenized) <= max_len + 2

        encoded = torch.tensor(tokenized).unsqueeze(0)
        inputs = OrderedDict([("wordpieces_inputs_words", encoded)])
        attention_mask = OrderedDict([("wordpieces_inputs_words",
                                       torch.tensor(mask).unsqueeze(0))])
        assert real_len
        if torch.cuda.is_available():
            inputs["wordpieces_inputs_words"] = inputs[
                "wordpieces_inputs_words"].cuda()

            attention_mask["wordpieces_inputs_words"] = attention_mask[
                "wordpieces_inputs_words"].cuda()
        model_output = model(input_ids_dict=inputs,
                             attention_mask=attention_mask)
        #pdb.set_trace()
        #logits = model_output[0]

        # getting the output index based on what we are asking the model
        hidden_state_per_layer_index = 2 if model.config.output_hidden_states else False
        attention_index_original_index = 3 - int(
            not hidden_state_per_layer_index
        ) if model.config.output_attentions else False
        hidden_state_per_layer_per_head_index = False  #4-int(not attention_index_original_index) if model.config.output_hidden_states_per_head else False
        # getting the output
        hidden_state_per_layer = model_output[
            hidden_state_per_layer_index] if hidden_state_per_layer_index else None
        attention = model_output[
            attention_index_original_index] if attention_index_original_index else None
        hidden_state_per_layer_per_head = model_output[
            hidden_state_per_layer_per_head_index] if hidden_state_per_layer_per_head_index else None

        # checking that we got the correct output
        try:
            if attention is not None:
                assert len(attention) == 12, "ERROR attention"
                assert attention[0].size()[-1] == attention[0].size(
                )[-2], "ERROR attention"
            if hidden_state_per_layer is not None:
                assert len(
                    hidden_state_per_layer) == 12 + 1, "ERROR hidden state"
                assert hidden_state_per_layer[0].size(
                )[-1] == 768, "ERROR hidden state"
            if hidden_state_per_layer_per_head is not None:
                assert len(hidden_state_per_layer_per_head
                           ) == 12, "ERROR hidden state per layer"
                assert hidden_state_per_layer_per_head[0].size(
                )[1] == 12 and hidden_state_per_layer_per_head[0].size(
                )[-1] == 64, "ERROR hidden state per layer"
        except Exception:
            # re-raise unchanged so the original traceback is preserved
            raise

        # concat as a batch per layer/layer_head
        if hidden_state_per_layer is not None:
            layer_hidden_state_dic = get_batch_per_layer_head(
                hidden_state_per_layer, layer_hidden_state_dic, head=False)
        if attention is not None:
            layer_head_att_batch_dic = get_batch_per_layer_head(
                attention, layer_head_att_batch_dic)
        if hidden_state_per_layer_per_head is not None:
            layer_head_hidden_state_dic = get_batch_per_layer_head(
                hidden_state_per_layer_per_head, layer_head_hidden_state_dic)

    output = ()
    if output_dic:
        if len(layer_hidden_state_dic) > 0:
            output = output + (layer_hidden_state_dic, )
        if len(layer_head_att_batch_dic) > 0:
            output = output + (layer_head_att_batch_dic, )
        if len(layer_head_hidden_state_dic) > 0:
            output = output + (layer_head_hidden_state_dic, )
    else:
        # concatenate into tensors
        # padding below max_len must be enabled so that sequences have equal length
        assert pad_below_max_len
        if len(layer_hidden_state_dic) > 0:
            for key in layer_hidden_state_dic:
                layer_hidden_state_tensor_dic[key] = torch.cat(
                    layer_hidden_state_dic[key], 0)
            output = output + (layer_hidden_state_tensor_dic, )
        if len(layer_head_att_batch_dic) > 0:
            for key in layer_head_att_batch_dic:
                layer_head_att_tensor_dic[key] = torch.cat(
                    layer_head_att_batch_dic[key], 0)
            output = output + (layer_head_att_tensor_dic, )
        if len(layer_head_hidden_state_dic) > 0:
            for key in layer_head_hidden_state_dic:
                layer_head_hidden_state_tensor_dic[key] = torch.cat(
                    layer_head_hidden_state_dic[key], 0)
            output = output + (layer_head_hidden_state_tensor_dic, )

    return output
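get_hidden_representation also relies on a get_batch_per_layer_head helper that is not shown. The sketch below only illustrates the accumulation pattern implied by how its output is concatenated above; the key names and the handling of the head argument are assumptions.

from collections import OrderedDict

def get_batch_per_layer_head(tensors_per_layer, accumulator, head=True):
    # Sketch only: append each layer's tensor to a per-layer list so the lists
    # can later be merged with torch.cat(..., 0). The real helper may also key
    # results per attention head when head=True; that detail is omitted here.
    for i_layer, tensor in enumerate(tensors_per_layer):
        key = f"layer_{i_layer}"
        accumulator.setdefault(key, [])
        accumulator[key].append(tensor.detach().cpu())
    return accumulator

# e.g. accumulating one fake "batch" of 12-layer outputs:
# acc = OrderedDict()
# acc = get_batch_per_layer_head([torch.randn(1, 5, 768)] * 12, acc)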