# Standard imports assumed by the snippets below (the original listing omits them).
# Project-specific modules such as CharEncoder, StackLSTMCell, StateStack,
# MultiLayerLSTMCell, TransitionSystems, AER_map, AH_map, Constants and utils.tensor
# are expected to come from the surrounding codebase; BertTokenizer and BertModel are
# assumed to come from the transformers package.
import logging
from typing import List

import torch
import torch.nn as nn
from torch.nn import init
from transformers import BertTokenizer, BertModel


class ParaNNTranSegmentor(nn.Module):
    """
    ParaNNTranSegmentor
    """

    def __init__(self, pretra_char_embed, char_embed_num, char_embed_dim, char_embed_dim_no_static, char_embed_max_norm,
                 pretra_bichar_embed, bichar_embed_num, bichar_embed_dim, bichar_embed_dim_no_static,
                 bichar_embed_max_norm, dropout_embed, encoder_embed_dim, dropout_encoder_embed, encoder_lstm_hid_size,
                 dropout_encoder_hid, subword_lstm_hid_size, word_lstm_hid_size, device):
        super(ParaNNTranSegmentor, self).__init__()

        self.char_encoder = CharEncoder(pretra_char_embed, char_embed_num, char_embed_dim, char_embed_dim_no_static,
                                        char_embed_max_norm, pretra_bichar_embed, bichar_embed_num, bichar_embed_dim,
                                        bichar_embed_dim_no_static, bichar_embed_max_norm, dropout_embed,
                                        encoder_embed_dim, dropout_encoder_embed, encoder_lstm_hid_size,
                                        dropout_encoder_hid, device)
        self.subwStackLSTM = SubwordLSTMCell(encoder_lstm_hid_size*2, subword_lstm_hid_size, device)
        self.wordStackLSTM = StackLSTMCell(2*subword_lstm_hid_size, word_lstm_hid_size, device)
        self.classifier = nn.Linear(word_lstm_hid_size+2*encoder_lstm_hid_size, 2, bias=True)
        self.device = device
        self.subword_action_map = torch.tensor([1, 2, 2]).to(self.device)
        self.word_action_map = torch.tensor([0, 1, 1]).to(self.device)
        self.__init_para()

    def forward(self, insts, golds=None):
        chars = self.char_encoder(insts)  # (seq_len, batch_size, encoder_lstm_hid_size*2)

        batch_size, seq_len = insts[0].shape[0], insts[0].shape[1]
        self.subwStackLSTM.init_stack(2*seq_len+2, batch_size)
        self.wordStackLSTM.init_stack(seq_len+1, batch_size)

        # the first prediction is fixed to class 1 (scores [-1., 1.]): the first character
        # necessarily starts a new word
        pred = [torch.tensor([[[-1., 1.]]]).expand((batch_size, 1, 2)).to(self.device)]
        for idx in range(1, seq_len, 1):
            subword = self.subwStackLSTM(chars[idx - 1, :, :])  # (batch_size, subword_lstm_hid_size)
            word_repre, _ = self.wordStackLSTM(subword)  # (batch_size, word_lstm_hid_size)
            output = self.classifier(torch.cat([word_repre, chars[idx, :, :]], 1))  # (batch_size, 2)

            if self.training:
                subwordOP = self.subword_action_map.index_select(0, golds[:, idx])
                wordOP = self.word_action_map.index_select(0, golds[:, idx])
            else:
                subwordOP = self.subword_action_map.index_select(0, torch.argmax(output, 1))
                wordOP = self.word_action_map.index_select(0, torch.argmax(output, 1))
            self.subwStackLSTM.update_pos(subwordOP)
            self.wordStackLSTM.update_pos(wordOP)

            pred.append(output.unsqueeze(1))
        return torch.cat(pred, 1)  # (batch_size, seq_len, 2)

    def __init_para(self):
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.uniform_(self.classifier.bias)
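

# --- Illustration (not part of the original model) ---------------------------------------
# The two action maps above translate a per-character decision (the gold label during
# training, the argmax of the classifier output during decoding) into one stack operation
# for the subword stack and one for the word stack via index_select. A minimal,
# self-contained sketch of that lookup with made-up labels:
def _demo_action_maps():
    subword_action_map = torch.tensor([1, 2, 2])
    word_action_map = torch.tensor([0, 1, 1])
    labels = torch.tensor([0, 1, 1, 0])  # hypothetical per-instance decisions for one step
    subword_ops = subword_action_map.index_select(0, labels)  # tensor([1, 2, 2, 1])
    word_ops = word_action_map.index_select(0, labels)        # tensor([0, 1, 1, 0])
    # the exact meaning of each op code is defined by StackLSTMCell.update_pos
    return subword_ops, word_ops

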
class SubwordLSTMCell(nn.Module):
    """
    SubwordStackLSTMCell class
    """
    def __init__(self, input_size, hidden_size, device):
        super(SubwordLSTMCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device
        self.lstm_r = StackLSTMCell(self.input_size, self.hidden_size, self.device)
        self.lstm_l = nn.LSTMCell(self.input_size, self.hidden_size, bias=True)
        # self.word_compose = nn.Sequential(
        #     nn.Linear(2*self.hidden_size, 2*self.hidden_size),
        #     nn.Tanh()
        # )
        self.__init_para()

    def init_stack(self, stack_size, batch_size):
        """
        init the stack
        :param stack_size: int, equals 2*seq_len+2, means the max size of stack.
        :param batch_size: int, the number of inst for a batch.
        :return:
        """
        self.lstm_r.init_stack(stack_size, batch_size)

    def forward(self, char):
        """
        generate subword/word representation by biLSTM and non-linear layer.
        :param char: (batch_size, self.input_size), output of char_encoder.
        :return: subword/word representation
        """
        h, _ = self.lstm_r(char)
        h_l, _ = self.lstm_l(char)
        # subword = self.word_compose(torch.cat([h, h_l], 1))
        # return subword
        return torch.cat([h, h_l], 1)

    def update_pos(self, op):
        """
        update the self.pos_word and self.pos_char depending on operation.
        :param op: (batch_size, ) -1 means pop, 0 means hold, 1 means push.
        :return:
        """
        self.lstm_r.update_pos(op)

    def __init_para(self):
        init.xavier_uniform_(self.lstm_l.weight_hh)
        init.xavier_uniform_(self.lstm_l.weight_ih)
        init.uniform_(self.lstm_l.bias_hh)
        init.uniform_(self.lstm_l.bias_ih)
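

# --- Illustration (not part of the original model) ---------------------------------------
# SubwordLSTMCell.forward concatenates the hidden state of the project's StackLSTMCell
# (lstm_r) with that of a plain nn.LSTMCell (lstm_l). The sketch below exercises only the
# torch-native half (the stack half is replaced by a zero placeholder) to show the shapes
# involved; note that calling nn.LSTMCell without an explicit (h, c) starts from zeros on
# every call, exactly as in forward() above.
def _demo_subword_composition(batch_size=4, input_size=16, hidden_size=8):
    lstm_l = nn.LSTMCell(input_size, hidden_size, bias=True)
    char = torch.randn(batch_size, input_size)   # stand-in for one step of encoder output
    h_l, _ = lstm_l(char)                        # (batch_size, hidden_size)
    h_r = torch.zeros(batch_size, hidden_size)   # placeholder for the StackLSTMCell output
    subword = torch.cat([h_r, h_l], 1)           # (batch_size, 2*hidden_size)
    return subword.shape                         # torch.Size([4, 16])

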
class StackLSTMParser(nn.Module):

  def __init__(self, vocab, actions, options, pre_vocab=None, postags=None):
    """
    :param vocab: vocabulary, with type torchtext.vocab.Vocab
    :param actions: a python list of actions, denoted as "transition" (unlabled) or "transition|label" (labeled).
           e.g. "Left-Arc", "Left-Arc|nsubj"
    :param pre_vocab: vocabulary with pre-trained word embedding, with type torchtext.vocab.Vocab.
           This vocabulary should contain the word embedding itself as well, see torchtext documents for details.
    :param postags: a python list of pos, in strings (optional)
    """
    super(StackLSTMParser, self).__init__()
    self.stack_size = options.stack_size
    self.hid_dim = options.hid_dim
    self.input_dim = options.input_dim
    self.max_step_length = options.max_step_length
    self.transSys = TransitionSystems(options.transSys)
    self.exposure_eps = options.exposure_eps
    self.num_lstm_layers = options.num_lstm_layers

    # gpu
    self.gpuid = options.gpuid
    self.dtype = torch.cuda.FloatTensor if len(options.gpuid) >= 1 else torch.FloatTensor
    self.long_dtype = torch.cuda.LongTensor if len(options.gpuid) >= 1 else torch.LongTensor

    # vocabularies
    self.vocab = vocab
    self.pre_vocab = pre_vocab
    self.actions = actions
    self.postags = postags

    # action mappings
    self.stack_action_mapping = torch.zeros(len(actions),).type(self.long_dtype)
    self.buffer_action_mapping = torch.zeros(len(actions),).type(self.long_dtype)
    self.set_action_mappings()

    # embeddings
    self.word_emb = nn.Embedding(len(vocab), options.word_emb_dim)
    self.action_emb = nn.Embedding(len(actions), options.action_emb_dim)

    if pre_vocab is not None:
      self.pre_word_emb = nn.Embedding(pre_vocab.vectors.size(0), pre_vocab.dim)
      # initialize from the pre-trained vectors and keep fixed (no gradient updates)
      self.pre_word_emb.weight = nn.Parameter(pre_vocab.vectors)
      self.pre_word_emb.weight.requires_grad = False
    else:
      self.pre_word_emb = None

    if postags is not None:
      self.postag_emb = nn.Embedding(len(postags), options.postag_emb_dim)
    else:
      self.postag_emb = None

    # token embedding composition
    token_comp_in_features = options.word_emb_dim # vocab only
    if pre_vocab is not None:
      # token_comp_in_features += options.pre_word_emb_dim
      token_comp_in_features += pre_vocab.dim
    if postags is not None:
      token_comp_in_features += options.postag_emb_dim
    self.compose_tokens = nn.Sequential(
        nn.Linear(token_comp_in_features, options.input_dim),
        nn.ReLU()
    )
    self.root_input = nn.Parameter(torch.rand(options.input_dim))

    # recurrent components
    # the only reason to share h0 and c0 is that pre_buffer and buffer should start from the same initial states,
    # but for simplicity they are reused for all three recurrent components (stack, buffer, action history)
    self.h0 = nn.Parameter(torch.rand(options.hid_dim,).type(self.dtype))
    self.c0 = nn.Parameter(torch.rand(options.hid_dim,).type(self.dtype))
    # FIXME: there is no dropout in StackLSTMCell at this moment
    # BufferLSTM could have 0 or 2 parameters, depending on what is passed for initial hidden and cell state
    self.stack = StackLSTMCell(options.input_dim, options.hid_dim, options.stack_size, options.num_lstm_layers, self.h0, self.c0)
    self.buffer = StateStack(options.hid_dim, self.h0)
    self.token_stack = StateStack(options.input_dim)
    self.token_buffer = StateStack(options.input_dim) # elements in this buffer have size input_dim, so h0 and c0 won't fit
    self.pre_buffer = MultiLayerLSTMCell(input_size=options.input_dim, hidden_size=options.hid_dim, num_layers=options.num_lstm_layers) # FIXME: dropout needs to be implemented manually
    self.history = MultiLayerLSTMCell(input_size=options.action_emb_dim, hidden_size=options.hid_dim, num_layers=options.num_lstm_layers) # FIXME: dropout needs to be implemented manually

    # parser state and softmax
    self.summarize_states = nn.Sequential(
        nn.Linear(3 * options.hid_dim, options.state_dim),
        nn.ReLU()
    )
    self.state_to_actions = nn.Linear(options.state_dim, len(actions))
    # forward() below applies self.softmax to the action scores; it is not defined in the
    # original listing, so a log-softmax over the actions is assumed here
    self.softmax = nn.LogSoftmax(dim=1)


  def forward(self, tokens, tokens_mask, pre_tokens=None, postags=None, actions=None):
    """

    :param tokens: (seq_len, batch_size) tokens of the input sentence, preprocessed as vocabulary indexes
    :param tokens_mask: (seq_len, batch_size) indication of whether this is a "real" token or a padding
    :param postags: (seq_len, batch_size) optional POS-tag of the tokens of the input sentence, preprocessed as indexes
    :param actions: (seq_len, batch_size) optional golden action sequence, preprocessed as indexes
    :return: output: (output_seq_len, batch_size, len(actions)), or when actions = None, (output_seq_len, batch_size, max_step_length)
    """

    seq_len = tokens.size()[0]
    batch_size = tokens.size()[1]
    word_emb = self.word_emb(tokens.t()) # (batch_size, seq_len, word_emb_dim)
    token_comp_input = word_emb
    if self.pre_word_emb is not None:
      pre_word_emb = self.pre_word_emb(pre_tokens.t()) # (batch_size, seq_len, self.pre_vocab.dim)
      token_comp_input = torch.cat((token_comp_input, pre_word_emb), dim=-1)
    if self.postag_emb is not None and postags is not None:
      pos_tag_emb = self.postag_emb(postags.t()) # (batch_size, seq_len, word_emb_dim)
      token_comp_input = torch.cat((token_comp_input, pos_tag_emb), dim=-1)
    token_comp_output = self.compose_tokens(token_comp_input) # (batch_size, seq_len, input_dim), pad at end
    token_comp_output = torch.transpose(token_comp_output, 0, 1) # (seq_len, batch_size, input_dim), pad at end
    token_comp_output_rev = utils.tensor.masked_revert(token_comp_output, tokens_mask.unsqueeze(2).expand(seq_len, batch_size, self.input_dim), 0) # (seq_len, batch_size, input_dim), pad at end

    # initialize stack
    self.stack.build_stack(batch_size, self.gpuid)

    # initialize and preload buffer
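    # the tokens are fed to pre_buffer in reverse order (token_comp_output_rev), presumably
    # so that the buffer's head is the first word of the sentence and popping yields the
    # tokens left to right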
    buffer_hiddens = torch.zeros((seq_len, batch_size, self.hid_dim)).type(self.dtype)
    buffer_cells = torch.zeros((seq_len, batch_size, self.hid_dim)).type(self.dtype)
    bh = self.h0.unsqueeze(0).unsqueeze(2).expand(batch_size, self.hid_dim, self.num_lstm_layers)
    bc = self.c0.unsqueeze(0).unsqueeze(2).expand(batch_size, self.hid_dim, self.num_lstm_layers)
    for t_i in range(seq_len):
      bh, bc = self.pre_buffer(token_comp_output_rev[t_i], (bh, bc)) # (batch_size, self.hid_dim, self.num_lstm_layers)
      buffer_hiddens[t_i, :, :] = bh[:, :, -1]
      buffer_cells[t_i, :, :] = bc[:, :, -1]

    self.buffer.build_stack(batch_size, seq_len, buffer_hiddens, tokens_mask, self.gpuid)
    self.token_stack.build_stack(batch_size, self.stack_size, gpuid=self.gpuid)
    self.token_buffer.build_stack(batch_size, seq_len, token_comp_output_rev, tokens_mask, self.gpuid)
    self.stack(self.root_input.unsqueeze(0).expand(batch_size, self.input_dim), torch.ones(batch_size).type(self.long_dtype)) # push root

    stack_state, _ = self.stack.head() # (batch_size, hid_dim)
    buffer_state = self.buffer.head() # (batch_size, hid_dim)
    token_stack_state = self.token_stack.head()
    token_buffer_state = self.token_buffer.head()
    stack_input =  token_buffer_state # (batch_size, input_size)
    action_state = self.h0.unsqueeze(0).unsqueeze(2).expand(batch_size, self.hid_dim, self.num_lstm_layers) # (batch_size, hid_dim, num_lstm_layers)
    action_cell = self.c0.unsqueeze(0).unsqueeze(2).expand(batch_size, self.hid_dim, self.num_lstm_layers) # (batch_size, hid_dim, num_lstm_layers)

    # prepare bernoulli probability for exposure indicator
    batch_exposure_prob = torch.Tensor([self.exposure_eps] * batch_size).type(self.dtype)

    # main loop
    if actions is None:
      outputs = torch.zeros((self.max_step_length, batch_size, len(self.actions))).type(self.dtype)
      step_length = self.max_step_length
    else:
      outputs = torch.zeros((actions.size()[0], batch_size, len(self.actions))).type(self.dtype)
      step_length = actions.size()[0]
    step_i = 0
    # during decoding, some instances in the batch may terminate earlier than others
    # and we cannot terminate the decoding because one instance terminated
    while step_i < step_length:
      # get action decisions
      summary = self.summarize_states(torch.cat((stack_state, buffer_state, action_state[:, :, -1]), dim=1)) # (batch_size, self.state_dim)
      action_dist = self.softmax(self.state_to_actions(summary)) # (batch_size, len(actions))
      outputs[step_i, :, :] = action_dist.clone()

      # get rid of forbidden actions (only for decoding)
      if actions is None or self.exposure_eps < 1.0:
        forbidden_actions = self.get_valid_actions(batch_size) ^ 1
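        # (`^ 1` on the ByteTensor mask above acts as a logical NOT, the pre-PyTorch-1.2
        # idiom for inverting a boolean mask)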
        num_forbidden_actions = torch.sum(forbidden_actions, dim=1)
        # if all actions except <pad> and <unk> are forbidden for every sentence in the batch,
        # there is no point in continuing to decode.
        if (num_forbidden_actions == (len(self.actions) - 2)).all():
          break
        action_dist.masked_fill_(forbidden_actions, -999)

      _, action_i = torch.max(action_dist, dim=1) # (batch_size,)

      # control exposure (only for training)
      if actions is not None:
        oracle_action_i = actions[step_i, :] # (batch_size,)
        batch_exposure = torch.bernoulli(batch_exposure_prob) # (batch_size,)
        action_i = (action_i.float() * (1 - batch_exposure) + oracle_action_i.float() * batch_exposure).long() # (batch_size,)
      # translate action into stack & buffer ops
      # stack_op, buffer_op = self.map_action(action_i.data)
      stack_op = self.stack_action_mapping.index_select(0, action_i)
      buffer_op = self.buffer_action_mapping.index_select(0, action_i)

      stack_state, _ = self.stack(stack_input, stack_op)
      buffer_state = self.buffer(stack_state, buffer_op) # actually nothing will be pushed
      token_stack_state = self.token_stack(stack_input, stack_op)
      token_buffer_state = self.token_buffer(stack_input, buffer_op)
      stack_input = self.token_buffer.head()
      action_input = self.action_emb(action_i) # (batch_size, action_emb_dim)
      action_state, action_cell = self.history(action_input, (action_state, action_cell)) # (batch_size, hid_dim, num_lstm_layers)

      step_i += 1

    return outputs


  def get_valid_actions(self, batch_size):
    action_mask = torch.ones(batch_size, len(self.actions)).type(self.long_dtype).byte()  # (batch_size, len(actions))

    for idx, action_str in enumerate(self.actions):
      # general popping safety
      sa = self.stack_action_mapping[idx].repeat(batch_size)
      ba = self.buffer_action_mapping[idx].repeat(batch_size)
      stack_forbid = ((sa == -1) & (self.stack.pos <= 1))
      buffer_forbid = ((ba == -1) & (self.buffer.pos == 0))
      action_mask[:, idx] = (action_mask[:, idx] & ((stack_forbid | buffer_forbid) ^ 1))

      # transition-system specific operation safety
      if self.transSys == TransitionSystems.AER:
        la_forbid = (((self.buffer.pos == 0) | (self.stack.pos <= 1)) & action_str.startswith("Left-Arc"))
        # NOTE: the original duplicated the stack check here; an empty-buffer check is the assumed intent for Right-Arc
        ra_forbid = (((self.buffer.pos == 0) | (self.stack.pos == 0)) & action_str.startswith("Right-Arc"))
        action_mask[:, idx] = (action_mask[:, idx] & ((la_forbid | ra_forbid) ^ 1))

      if self.transSys == TransitionSystems.AH:
        la_forbid = (((self.buffer.pos == 0) | (self.stack.pos <= 1)) & action_str.startswith("Left-Arc"))
        ra_forbid = (self.stack.pos <= 1) & action_str.startswith("Right-Arc")
        action_mask[:, idx] = (action_mask[:, idx] & ((la_forbid | ra_forbid) ^ 1))

    return action_mask


  def set_action_mappings(self):
    for idx, astr in enumerate(self.actions):
      transition = astr.split('|')[0]

      # If the action is unknown, leave the stack and buffer state as-is
      if self.transSys == TransitionSystems.AER:
        self.stack_action_mapping[idx], self.buffer_action_mapping[idx] = AER_map.get(transition, (0, 0))
      elif self.transSys == TransitionSystems.AH:
        self.stack_action_mapping[idx], self.buffer_action_mapping[idx] = AH_map.get(transition, (0, 0))
      else:
        logging.fatal("Unimplemented transition system.")
        raise NotImplementedError
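

# --- Illustration (not part of the original parser) ---------------------------------------
# During training, forward() above mixes the model's own action choices with the oracle
# actions through a per-instance Bernoulli draw with probability exposure_eps (a form of
# scheduled sampling). A minimal, self-contained sketch of that mixing step with made-up
# numbers:
def _demo_exposure_mixing(exposure_eps=0.8):
    predicted = torch.tensor([3, 1, 0, 2])    # argmax of the action distribution
    oracle = torch.tensor([2, 1, 1, 2])       # golden actions for this step
    batch_exposure = torch.bernoulli(torch.full((4,), exposure_eps))  # 1 -> take the oracle
    mixed = (predicted.float() * (1 - batch_exposure)
             + oracle.float() * batch_exposure).long()  # (batch_size,)
    return mixed

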
class BertStackSegmentor(nn.Module):
    def __init__(self, device):
        super(BertStackSegmentor, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.bert_model = BertModel.from_pretrained('bert-base-chinese')

        self.lstm = nn.LSTM(768, 768, batch_first=True, bidirectional=True)
        self.subwStackLSTM = SubwordLSTMCell(768 * 2, 768, device)
        self.wordStackLSTM = StackLSTMCell(2 * 768, 768, device)
        self.cls = nn.Linear(768 * 3, 2)
        self.device = device
        self.subword_action_map = torch.tensor([1, 2, 2]).to(self.device)
        self.word_action_map = torch.tensor([0, 1, 1]).to(self.device)
        self.__init_para()

    def forward(self, insts, golds):
        batch_size, seq_len = golds.shape
        insts, mask = self.__tokenize(batch_size, seq_len, insts, golds)
        hidden_state = self.bert_model(
            insts,
            attention_mask=mask)[0][:, 1:seq_len +
                                    1, :]  # (batch_size, seq_len, hidden_size)
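        # the 1:seq_len+1 slice drops the leading [CLS] vector and the trailing
        # [SEP]/padding positions, leaving one hidden state per input character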
        lstm_output, _ = self.lstm(
            hidden_state)  # (batch_size, seq_len, 768*2)

        self.subwStackLSTM.init_stack(2 * seq_len + 2, batch_size)
        self.wordStackLSTM.init_stack(seq_len + 1, batch_size)

        pred = [
            torch.tensor([[[-1., 1.]]]).expand(
                (batch_size, 1, 2)).to(self.device)
        ]
        for idx in range(1, seq_len, 1):
            subword = self.subwStackLSTM(
                lstm_output[:, idx - 1, :])  # (batch_size, 768*2)
            word_repre, _ = self.wordStackLSTM(subword)  # (batch_size, 768)
            output = self.cls(
                torch.cat([word_repre, lstm_output[:, idx, :]],
                          1))  # (batch_size, 2)

            if self.training:
                subwordOP = self.subword_action_map.index_select(
                    0, golds[:, idx])
                wordOP = self.word_action_map.index_select(0, golds[:, idx])
            else:
                subwordOP = self.subword_action_map.index_select(
                    0, torch.argmax(output, 1))
                wordOP = self.word_action_map.index_select(
                    0, torch.argmax(output, 1))
            self.subwStackLSTM.update_pos(subwordOP)
            self.wordStackLSTM.update_pos(wordOP)

            pred.append(output.unsqueeze(1))
        return torch.cat(pred, 1)  # (batch_size, seq_len, 2)

    def __tokenize(self, batch_size: int, seq_len: int, insts: List[str],
                   golds: torch.Tensor):
        insts = [
            self.tokenizer.encode(inst) + [self.tokenizer.pad_token_id] *
            (seq_len - (len(inst) + 1) // 2) for inst in insts
        ]
        insts = torch.tensor(insts).to(self.device)
        assert insts.shape[0] == batch_size and insts.shape[
            1] == seq_len + 2, 'inst tokenization went wrong.'
        mask = torch.ones((batch_size, seq_len + 2),
                          dtype=torch.long).to(self.device)
        mask[:, 2:] = golds != Constants.actionPadId
        return insts, mask

    def __init_para(self):
        nn.init.xavier_uniform_(self.cls.weight)
        nn.init.uniform_(self.cls.bias)
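

# --- Hypothetical usage sketch (not from the original repository) -------------------------
# How a training-mode forward pass might be driven, assuming `insts` is a batch of
# space-separated character strings and `golds` a (batch_size, seq_len) LongTensor of
# per-character action ids. The example sentences and gold ids below are made up, and
# running this requires network access for 'bert-base-chinese' plus the project-level
# StackLSTMCell / Constants modules, so treat it as a sketch only.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BertStackSegmentor(device).to(device)
    insts = ["今 天 天 气", "他 来 北 京"]                         # two 4-character instances
    golds = torch.tensor([[1, 0, 1, 0], [1, 1, 1, 0]]).to(device)  # made-up action ids
    model.train()
    scores = model(insts, golds)  # (batch_size, seq_len, 2)
    print(scores.shape)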