# Example #1
# 0
class EncoderDecoderModel:
    """LSTM encoder-decoder (seq2seq) model built on Chainer v1's FunctionSet.

    The encoder consumes batches of source word ids and the decoder emits
    target word ids, trained with teacher forcing in ``forward``/``fit`` and
    decoded greedily in ``predict``.
    """

    def __init__(self,
                 src_vocab,
                 trg_vocab,
                 n_embed=256,
                 n_hidden=512,
                 algorithm='Adam'):
        """Build the parameter set.

        src_vocab / trg_vocab: word -> integer-id mappings (``len`` gives the
        vocabulary size).  algorithm selects the optimizer created later by
        ``initialize_optimizer``.
        """
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.n_embed = n_embed
        self.n_hidden = n_hidden
        self.algorithm = algorithm
        # Every layer feeding F.lstm outputs 4 * n_hidden units: one slice
        # each for the input, forget, and output gates plus the cell input.
        self.model = FunctionSet(embed_x=F.EmbedID(len(src_vocab), n_embed),
                                 en_x_to_h=F.Linear(n_embed, 4 * n_hidden),
                                 en_h_to_h=F.Linear(n_hidden, 4 * n_hidden),
                                 en_h_to_de_h=F.Linear(n_hidden, 4 * n_hidden),
                                 de_h_to_embed_y=F.Linear(n_hidden, n_embed),
                                 embed_y_to_y=F.Linear(n_embed,
                                                       len(trg_vocab)),
                                 y_to_h=F.EmbedID(len(trg_vocab),
                                                  4 * n_hidden),
                                 de_h_to_h=F.Linear(n_hidden, 4 * n_hidden))

    def get_model(self):
        """Return the underlying FunctionSet holding all parameters."""
        return self.model

    def _batch_column(self, batch, i):
        """Return column ``i`` of a batch of equal-length id sequences as an
        int32 Variable (EmbedID and softmax_cross_entropy require int32)."""
        return Variable(
            np.array([batch[k][i] for k in range(len(batch))],
                     dtype=np.int32))

    def _encode(self, src_batch):
        """Run the encoder LSTM over ``src_batch``.

        Returns the final ``(lstm_c, en_h)`` cell state and hidden state.
        Assumes all sentences in the batch have the same length.
        """
        n_batch = len(src_batch)
        lstm_c = self.initialize_state(n_batch)
        en_h = None
        for i in range(len(src_batch[0])):
            x = self._batch_column(src_batch, i)
            en_x = F.tanh(self.model.embed_x(x))
            if en_h is None:
                # First step: no previous hidden state to feed back.
                lstm_c, en_h = F.lstm(lstm_c, self.model.en_x_to_h(en_x))
            else:
                lstm_c, en_h = F.lstm(
                    lstm_c,
                    self.model.en_x_to_h(en_x) + self.model.en_h_to_h(en_h))
        return lstm_c, en_h

    def forward(self, src_batch, trg_batch):
        """Teacher-forced forward pass over one batch.

        Returns ``(hyp_sents, accum_loss)``: per-sentence argmax word ids and
        the summed softmax cross-entropy over all target positions.
        """
        # encode (debug print from the original training loop removed)
        n_batch = len(src_batch)
        lstm_c, en_h = self._encode(src_batch)

        # decode
        hyp_sents = [[] for _ in range(n_batch)]
        accum_loss = Variable(np.zeros(()).astype(np.float32))
        lstm_c, de_h = F.lstm(lstm_c, self.model.en_h_to_de_h(en_h))
        for i in range(len(trg_batch[0])):
            embed_y = F.tanh(self.model.de_h_to_embed_y(de_h))
            y = self.model.embed_y_to_y(embed_y)
            t = self._batch_column(trg_batch, i)
            accum_loss += F.softmax_cross_entropy(y, t)
            output = y.data.argmax(1)
            for k in range(n_batch):
                hyp_sents[k].append(output[k])
            # Teacher forcing: feed the gold word t, not the prediction.
            lstm_c, de_h = F.lstm(
                lstm_c,
                self.model.de_h_to_h(de_h) + self.model.y_to_h(t))
        return hyp_sents, accum_loss

    def fit(self, src_batch, trg_batch):
        """One optimizer step on a batch; returns the predicted sentences.

        Requires ``initialize_optimizer`` to have been called first.
        """
        self.optimizer.zero_grads()
        hyp_sents, accum_loss = self.forward(src_batch, trg_batch)
        accum_loss.backward()
        self.optimizer.clip_grads(10)
        self.optimizer.update()
        return hyp_sents

    def predict(self, src_batch, sent_len_limit):
        """Greedily decode up to ``sent_len_limit`` words per sentence.

        Stops early once every sentence in the batch has emitted '</s>'.
        """
        # encode
        n_batch = len(src_batch)
        lstm_c, en_h = self._encode(src_batch)

        # decode: output the highest-probability word at each step
        lstm_c, de_h = F.lstm(lstm_c, self.model.en_h_to_de_h(en_h))
        hyp_sents = [[] for _ in range(n_batch)]
        while len(hyp_sents[0]) < sent_len_limit:
            embed_y = F.tanh(self.model.de_h_to_embed_y(de_h))
            y = self.model.embed_y_to_y(embed_y)
            output = y.data.argmax(1)
            for k in range(n_batch):
                hyp_sents[k].append(output[k])
            # Feed predictions back in.  argmax returns int64 on most
            # platforms, but EmbedID requires int32 ids — cast explicitly.
            prev = Variable(output.astype(np.int32))
            lstm_c, de_h = F.lstm(
                lstm_c,
                self.model.de_h_to_h(de_h) + self.model.y_to_h(prev))
            # BUG FIX: the original read a global ``trg_vocab`` here, which
            # raised NameError; the vocabulary lives on self.
            if all(hyp_sents[k][-1] == self.trg_vocab['</s>']
                   for k in range(n_batch)):
                break
        return hyp_sents

    def initialize_optimizer(self, lr=0.5):
        """Create and set up the optimizer named by ``self.algorithm``.

        ``lr`` only applies to SGD; the adaptive optimizers use their
        Chainer defaults.  Raises AssertionError on an unknown name
        (kept for backward compatibility with existing callers).
        """
        if self.algorithm == 'SGD':
            self.optimizer = optimizers.SGD(lr=lr)
        elif self.algorithm == 'Adam':
            self.optimizer = optimizers.Adam()
        elif self.algorithm == 'Adagrad':
            self.optimizer = optimizers.AdaGrad()
        elif self.algorithm == 'Adadelta':
            self.optimizer = optimizers.AdaDelta()
        else:
            raise AssertionError('this algorithm is not available')
        self.optimizer.setup(self.model)

    def initialize_state(self, n_batch):
        """Return a zeroed (n_batch, n_hidden) float32 LSTM cell state."""
        return Variable(np.zeros((n_batch, self.n_hidden), dtype=np.float32))