示例#1
0
def build_tag_graph():
    """Build the batched BiLSTM tagger graph and compile its Theano functions.

    Relies on the module-level names ``nwords``, ``ntags`` and ``args``
    (``WEMBED_SIZE``, ``HIDDEN_SIZE``, ``MLP_SIZE``).

    Returns:
        ``(train_loss_func, decode_func)`` where

        * ``train_loss_func(x, y)`` returns the summed, mask-weighted tag NLL
          of the batch and applies one Adam update to every parameter;
        * ``decode_func(x)`` returns per-token log tag probabilities reshaped
          to ``(x.shape[0], x.shape[1], -1)``.
    """
    print('build graph..', file=sys.stderr)

    # Symbolic int32 matrices of word ids and gold tag ids, both laid out as
    # (batch_size, sentence_length).
    sentences = T.imatrix(name='sentence')
    gold_tags = T.imatrix(name='tag')

    # Parameters are created in a fixed order: embedding table, BiLSTM, then
    # the two MLP weight matrices.
    embedding_table = Embedding(nwords, args.WEMBED_SIZE)
    lstm = BiLSTM(args.WEMBED_SIZE, args.HIDDEN_SIZE, return_sequences=True)
    W_mlp_hidden = uniform((args.HIDDEN_SIZE * 2, args.MLP_SIZE),
                           name='W_mlp_hidden')
    W_mlp = uniform((args.MLP_SIZE, ntags), name='W_mlp')

    # mask_zero=True makes the lookup also return a mask; presumably id 0 is
    # padding — confirm against the Embedding implementation.
    embedded, mask = embedding_table(sentences, mask_zero=True)
    encoded = lstm(embedded, mask=mask)

    # Two-layer MLP scorer: tanh hidden layer, then a linear map to ntags.
    mlp_hidden = T.tanh(T.dot(encoded, W_mlp_hidden))
    scores = T.dot(mlp_hidden, W_mlp)

    # Merge batch and time axes so each token is one softmax row.
    n_tokens = scores.shape[0] * scores.shape[1]
    flat_scores = scores.reshape((n_tokens, -1))
    log_probs = T.log(T.nnet.softmax(flat_scores))

    # Gather each token's gold-tag log-probability; multiplying by the
    # flattened mask drops the positions the mask zeroes.
    token_rows = T.arange(log_probs.shape[0])
    gold_log_probs = log_probs[token_rows, gold_tags.flatten()]
    tag_nll = -gold_log_probs * mask.flatten()
    loss = tag_nll.sum()

    params = embedding_table.params + lstm.params + [W_mlp_hidden, W_mlp]
    updates = Adam().get_updates(params, loss)
    train_loss_func = theano.function([sentences, gold_tags], loss,
                                      updates=updates)

    # Decoding reuses the flat log-probabilities, restored to 3-D.
    per_sentence = log_probs.reshape(
        (sentences.shape[0], sentences.shape[1], -1))
    decode_func = theano.function([sentences], per_sentence)

    return train_loss_func, decode_func
示例#2
0
    def __init__(self):
        """Create every sub-layer and raw parameter tensor of the model.

        All sizes are read from the module-level ``config``. After
        construction:

        * ``self.params`` is the concatenated list of trainable parameters
          (layer ``.params`` lists plus the embedding/bias tensors created
          here), in a fixed order;
        * ``self.srng`` holds a ``RandomStreams`` instance.
        """
        self.query_embedding = Embedding(config.source_vocab_size,
                                         config.word_embed_dim,
                                         name='query_embed')

        if config.encoder_lstm == 'bilstm':
            # Half of encoder_hidden_dim per direction, presumably so the two
            # directions together match encoder_hidden_dim (the decoder below
            # uses encoder_hidden_dim as the context size) — confirm.
            # Floor division keeps the size an int on Python 3, where ``/``
            # would give a float; on Python 2 ints it equals the old ``/``.
            self.query_encoder_lstm = BiLSTM(config.word_embed_dim,
                                             config.encoder_hidden_dim // 2,
                                             return_sequences=True,
                                             name='query_encoder_lstm')
        else:
            self.query_encoder_lstm = LSTM(config.word_embed_dim,
                                           config.encoder_hidden_dim,
                                           return_sequences=True,
                                           name='query_encoder_lstm')

        # Input width is rule_embed_dim + node_embed_dim + rule_embed_dim.
        self.decoder_lstm = CondAttLSTM(config.rule_embed_dim + config.node_embed_dim + config.rule_embed_dim,
                                        config.decoder_hidden_dim,
                                        config.encoder_hidden_dim,
                                        config.attention_hidden_dim,
                                        name='decoder_lstm')

        self.src_ptr_net = PointerNet()

        # 2-way softmax over the decoder hidden state.
        self.terminal_gen_softmax = Dense(config.decoder_hidden_dim, 2,
                                          activation='softmax',
                                          name='terminal_gen_softmax')

        self.rule_embedding_W = initializations.get('normal')(
            (config.rule_num, config.rule_embed_dim),
            name='rule_embedding_W', scale=0.1)
        self.rule_embedding_b = shared_zeros(config.rule_num,
                                             name='rule_embedding_b')

        self.node_embedding = initializations.get('normal')(
            (config.node_num, config.node_embed_dim),
            name='node_embed', scale=0.1)

        self.vocab_embedding_W = initializations.get('normal')(
            (config.target_vocab_size, config.rule_embed_dim),
            name='vocab_embedding_W', scale=0.1)
        self.vocab_embedding_b = shared_zeros(config.target_vocab_size,
                                              name='vocab_embedding_b')

        # decoder_hidden_dim -> action embed
        self.decoder_hidden_state_W_rule = Dense(
            config.decoder_hidden_dim, config.rule_embed_dim,
            name='decoder_hidden_state_W_rule')

        # (decoder_hidden_dim + encoder_hidden_dim) -> action embed
        self.decoder_hidden_state_W_token = Dense(
            config.decoder_hidden_dim + config.encoder_hidden_dim,
            config.rule_embed_dim,
            name='decoder_hidden_state_W_token')

        self.params = self.query_embedding.params + self.query_encoder_lstm.params + \
                      self.decoder_lstm.params + self.src_ptr_net.params + self.terminal_gen_softmax.params + \
                      [self.rule_embedding_W, self.rule_embedding_b, self.node_embedding, self.vocab_embedding_W, self.vocab_embedding_b] + \
                      self.decoder_hidden_state_W_rule.params + self.decoder_hidden_state_W_token.params

        self.srng = RandomStreams()
示例#3
0
    def __init__(self):
        """Create every sub-layer and raw parameter tensor of the model.

        All sizes are read from the module-level ``config``. The width of the
        encoder input depends on ``config.concat_type``:

        * ``'basic'``: ``word_embed_dim + 2``, or ``+ 3`` when
          ``config.include_cid == True``;
        * anything else: phrase (14 -> 8), POS (44 -> 32) and canon
          (102 -> 64) embedding tables are created, together with a linear
          ``concat_projector`` mapping ``word_embed_dim + aug_dim`` back to
          ``word_embed_dim``. ``aug_dim`` is 40, or 104 with
          ``include_cid``, and the encoder input stays ``word_embed_dim``.

        After construction ``self.params`` lists the trainable parameters in
        a fixed order, and ``self.srng`` holds a ``RandomStreams`` instance.
        """
        self.query_embedding = Embedding(config.source_vocab_size,
                                         config.word_embed_dim,
                                         name='query_embed')

        encoder_dim = config.word_embed_dim
        # Lazy logging arguments: the message is formatted only if emitted.
        logging.info("Concatenation type: %s", config.concat_type)
        logging.info("Include canon_id matrix: %s", config.include_cid)
        if config.concat_type == 'basic':
            # Extra feature columns are appended directly to the embedding.
            encoder_dim += 2
            if config.include_cid == True:
                encoder_dim += 1
        else:
            # define layers
            self.query_phrase_embedding = Embedding(14,
                                                    8,
                                                    name='query_phrase_embed')
            self.query_pos_embedding = Embedding(44,
                                                 32,
                                                 name='query_pos_embed')
            self.query_canon_embedding = Embedding(
                102, 64, name='query_canon_embedding')
            aug_dim = 8 + 32
            if config.include_cid == True:
                aug_dim += 64
            # Projects [word embedding ; augmented features] back down to
            # word_embed_dim, so encoder_dim is unchanged on this path.
            self.projector = Dense(config.word_embed_dim + aug_dim,
                                   config.word_embed_dim,
                                   activation='linear',
                                   name='concat_projector')

        if config.encoder == 'bilstm':
            # Half of encoder_hidden_dim per direction, presumably so the two
            # directions together match encoder_hidden_dim (the decoder below
            # uses encoder_hidden_dim as the context size) — confirm.
            # Floor division keeps the size an int on Python 3, where ``/``
            # would give a float; on Python 2 ints it equals the old ``/``.
            self.query_encoder_lstm = BiLSTM(encoder_dim,
                                             config.encoder_hidden_dim // 2,
                                             return_sequences=True,
                                             name='query_encoder_lstm')
        else:
            self.query_encoder_lstm = LSTM(encoder_dim,
                                           config.encoder_hidden_dim,
                                           return_sequences=True,
                                           name='query_encoder_lstm')

        # Input width is rule_embed_dim + node_embed_dim + rule_embed_dim.
        self.decoder_lstm = CondAttLSTM(config.rule_embed_dim +
                                        config.node_embed_dim +
                                        config.rule_embed_dim,
                                        config.decoder_hidden_dim,
                                        config.encoder_hidden_dim,
                                        config.attention_hidden_dim,
                                        name='decoder_lstm')
        self.src_ptr_net = PointerNet()

        # 2-way softmax over the decoder hidden state.
        self.terminal_gen_softmax = Dense(config.decoder_hidden_dim,
                                          2,
                                          activation='softmax',
                                          name='terminal_gen_softmax')

        self.rule_embedding_W = initializations.get('normal')(
            (config.rule_num, config.rule_embed_dim),
            name='rule_embedding_W',
            scale=0.1)
        self.rule_embedding_b = shared_zeros(config.rule_num,
                                             name='rule_embedding_b')

        self.node_embedding = initializations.get('normal')(
            (config.node_num, config.node_embed_dim),
            name='node_embed',
            scale=0.1)

        self.vocab_embedding_W = initializations.get('normal')(
            (config.target_vocab_size, config.rule_embed_dim),
            name='vocab_embedding_W',
            scale=0.1)
        self.vocab_embedding_b = shared_zeros(config.target_vocab_size,
                                              name='vocab_embedding_b')

        # decoder_hidden_dim -> action embed
        self.decoder_hidden_state_W_rule = Dense(
            config.decoder_hidden_dim,
            config.rule_embed_dim,
            name='decoder_hidden_state_W_rule')

        # (decoder_hidden_dim + encoder_hidden_dim) -> action embed
        self.decoder_hidden_state_W_token = Dense(
            config.decoder_hidden_dim + config.encoder_hidden_dim,
            config.rule_embed_dim,
            name='decoder_hidden_state_W_token')

        # NOTE(review): when concat_type != 'basic', the params of
        # query_phrase_embedding, query_pos_embedding, query_canon_embedding
        # and projector are not included here. Unless they are appended
        # elsewhere, they will not appear in self.params — verify.
        self.params = self.query_embedding.params + self.query_encoder_lstm.params + \
                      self.decoder_lstm.params + self.src_ptr_net.params + self.terminal_gen_softmax.params + \
                      [self.rule_embedding_W, self.rule_embedding_b, self.node_embedding, self.vocab_embedding_W, self.vocab_embedding_b] + \
                      self.decoder_hidden_state_W_rule.params + self.decoder_hidden_state_W_token.params

        self.srng = RandomStreams()
示例#4
0
def build_tag_graph():
    """Build the single-sentence tagger with a character-level UNK backoff.

    Words whose id equals the module-level ``UNK`` take their embedding from
    a character BiLSTM; all other words use their word-embedding row. The
    sentence is encoded by a BiLSTM and scored by a two-layer MLP. Relies on
    the module-level names ``nwords``, ``nchars``, ``ntags``, ``UNK`` and
    ``args`` (``WEMBED_SIZE``, ``CEMBED_SIZE``, ``HIDDEN_SIZE``, ``MLP_SIZE``).

    Returns:
        ``(train_loss_func, decode_func)`` where

        * ``train_loss_func(x, x_chars, y)`` returns the summed tag NLL of
          one sentence and applies one Adam update to every parameter;
        * ``decode_func(x, x_chars)`` returns the per-word log tag
          probabilities.
    """
    print('build graph..', file=sys.stderr)

    # One sentence at a time:
    #   words      -- (sentence_length,) word ids
    #   word_chars -- (sentence_length, max_char_num_per_word) char ids
    #   gold_tags  -- (sentence_length,) target tag ids
    words = T.ivector(name='sentence')
    word_chars = T.imatrix(name='sent_word_chars')
    gold_tags = T.ivector(name='tag')

    # Parameters are created in a fixed order: word table, char table,
    # char BiLSTM, sentence BiLSTM, then the two MLP weight matrices.
    word_embeddings = Embedding(nwords, args.WEMBED_SIZE,
                                name='word_embeddings')
    char_embeddings = Embedding(nchars, args.CEMBED_SIZE,
                                name='char_embeddings')
    # Hidden size is half of WEMBED_SIZE; its output replaces a word
    # embedding below, presumably via concatenated directions — confirm.
    char_lstm = BiLSTM(args.CEMBED_SIZE, int(args.WEMBED_SIZE / 2),
                       name='char_lstm')
    lstm = BiLSTM(args.WEMBED_SIZE, args.HIDDEN_SIZE,
                  return_sequences=True, name='lstm')
    W_mlp_hidden = uniform((args.HIDDEN_SIZE * 2, args.MLP_SIZE),
                           name='W_mlp_hidden')
    W_mlp = uniform((args.MLP_SIZE, ntags), name='W_mlp')

    # 1.0 for UNK words and 0.0 otherwise, as a column so it broadcasts
    # across the embedding dimension.
    is_unk = T.eq(words, UNK).astype('float32')[:, None]

    # Both embeddings are computed for every word; is_unk selects one.
    from_lookup = word_embeddings(words)
    char_embeds, char_mask = char_embeddings(word_chars, mask_zero=True)
    from_chars = char_lstm(char_embeds, mask=char_mask)
    sent_embed = is_unk * from_chars + (1 - is_unk) * from_lookup

    # The sentence BiLSTM is called on a 3-D input: add a leading size-1
    # axis (unbroadcast so it is not flagged broadcastable) and take [0].
    batched = T.unbroadcast(sent_embed[None, :, :], 0)
    lstm_output = lstm(batched)[0]

    # Two-layer MLP scorer followed by log-softmax over tags.
    mlp_hidden = T.tanh(T.dot(lstm_output, W_mlp_hidden))
    tag_prob = T.log(T.nnet.softmax(T.dot(mlp_hidden, W_mlp)))

    gold_log_probs = tag_prob[T.arange(tag_prob.shape[0]), gold_tags]
    loss = (-gold_log_probs).sum()

    params = (word_embeddings.params + char_embeddings.params +
              char_lstm.params + lstm.params + [W_mlp_hidden, W_mlp])
    updates = Adam().get_updates(params, loss)
    train_loss_func = theano.function([words, word_chars, gold_tags], loss,
                                      updates=updates)

    # Decoding returns the log tag probabilities directly.
    decode_func = theano.function([words, word_chars], tag_prob)

    return train_loss_func, decode_func