import math

import numpy as np

from primitiv import Model, Node, Parameter, Shape
from primitiv import functions as F
from primitiv import initializers as I

# NOTE: `function_type`, `ScaledDotProductAttention`, `LSTM` (for the
# encoder-decoder listings), and `DROPOUT_RATE` are defined elsewhere in the
# accompanying code; the primitiv imports above are the ones these listings
# appear to assume.


class LayerNorm(Model):
    """Layer normalization over the feature (d_model) axis."""

    def __init__(self, eps=1e-6):
        self.eps = eps
        self.pgain = Parameter()
        self.pbias = Parameter()
        self.scan_attributes()

    def init(self, d_model):
        self.pgain.init([1, d_model], I.Constant(1))
        self.pbias.init([1, d_model], I.Constant(0))

    @function_type
    def __call__(self, x, train):
        seq_len = x.shape()[0]
        d_model = x.shape()[1]
        gain = self.F.broadcast(self.F.parameter(self.pgain), 0, seq_len)
        bias = self.F.broadcast(self.F.parameter(self.pbias), 0, seq_len)
        mean = self.F.mean(x, 1)
        # Var[x] = E[x^2] - E[x]^2, taken along the feature axis.
        std = self.F.sqrt(self.F.mean(x * x, 1) - mean * mean)
        mean = self.F.broadcast(mean, 1, d_model)
        std = self.F.broadcast(std, 1, d_model)
        return gain * (x - mean) / (std + self.eps) + bias
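# Quick NumPy cross-check of the layer-norm arithmetic above (an
# illustrative sketch, not part of the model; the shape and random input
# are arbitrary). Note the biased standard deviation, matching the code.
x_chk = np.random.randn(3, 8)                        # [seq_len, d_model]
mean_chk = x_chk.mean(axis=1, keepdims=True)         # E[x] per position
std_chk = np.sqrt((x_chk * x_chk).mean(axis=1, keepdims=True)
                  - mean_chk * mean_chk)
y_chk = (x_chk - mean_chk) / (std_chk + 1e-6)        # gain=1, bias=0
assert np.allclose(y_chk.mean(axis=1), 0, atol=1e-6)  # ~0 mean per row
assert np.allclose(y_chk.std(axis=1), 1, atol=1e-3)   # ~1 std per row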
class PositionwiseFeedForward(Model):
    def __init__(self, dropout):
        self.dropout = dropout
        self.pw1 = Parameter()
        self.pb1 = Parameter()
        self.pw2 = Parameter()
        self.pb2 = Parameter()
        self.scan_attributes()

    def init(self, d_model, d_ff):
        self.pw1.init([d_model, d_ff], I.XavierUniform())
        self.pb1.init([1, d_ff], I.XavierUniform())
        self.pw2.init([d_ff, d_model], I.XavierUniform())
        self.pb2.init([1, d_model], I.XavierUniform())

    @function_type
    def __call__(self, x, train):
        seq_len = x.shape()[0]
        w1 = self.F.parameter(self.pw1)
        w2 = self.F.parameter(self.pw2)
        b1 = self.F.broadcast(self.F.parameter(self.pb1), 0, seq_len)
        b2 = self.F.broadcast(self.F.parameter(self.pb2), 0, seq_len)
        h = self.F.dropout(self.F.relu(x @ w1 + b1), self.dropout, train)
        return h @ w2 + b2
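# The "position-wise" property can be checked in plain NumPy: applying the
# two-layer ReLU network to the whole [seq_len, d_model] matrix equals
# applying it row by row (dropout omitted; weights and sizes below are
# illustrative assumptions, not model values).
rng = np.random.default_rng(0)
x_ff = rng.standard_normal((3, 8))                   # [seq_len, d_model]
w1_ff = rng.standard_normal((8, 32))                 # [d_model, d_ff]
b1_ff = rng.standard_normal(32)
w2_ff = rng.standard_normal((32, 8))                 # [d_ff, d_model]
b2_ff = rng.standard_normal(8)

ffn = lambda v: np.maximum(v @ w1_ff + b1_ff, 0) @ w2_ff + b2_ff
assert np.allclose(ffn(x_ff), np.stack([ffn(row) for row in x_ff]))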
class TransformerEmbeddings(Model):
    def __init__(self, max_len, dropout):
        self.max_len = max_len
        self.dropout = dropout
        self.pe = None
        self.plookup = Parameter()
        self.pby = Parameter()
        self.scan_attributes()

    def init(self, vocab, d_model):
        self.plookup.init([d_model, vocab], I.XavierUniform())
        self.pby.init([1, vocab], I.XavierUniform())

    @function_type
    def encode(self, seq, train):
        lookup = self.F.parameter(self.plookup)
        d_model = lookup.shape()[0]
        if self.pe is None:
            self.pe = self.positional_encoding()
        embed = []
        for w in seq:
            e = self.F.pick(lookup, w, 1)
            embed.append(e)
        embed_tensor = self.F.transpose(self.F.concat(embed, 1))
        embed_tensor *= math.sqrt(d_model)
        pos = self.F.input(self.pe[:len(seq)])
        pe = self.F.dropout(embed_tensor + pos, self.dropout, train)
        return pe

    @function_type
    def decode(self, x, train):
        # x: [seq_len, d_model]
        w = self.F.parameter(self.plookup)  # [d_model, vocab]
        by = self.F.broadcast(
            self.F.parameter(self.pby), 0, x.shape()[0])  # [seq_len, vocab]
        return x @ w + by  # [seq_len, vocab]

    def positional_encoding(self):
        d_model = self.plookup.shape()[0]
        pe = np.zeros((self.max_len, d_model))
        position = np.expand_dims(np.arange(0, self.max_len), axis=1)
        div_term = np.exp(
            np.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        div_term = np.expand_dims(div_term, axis=0)
        pe[:, 0::2] = np.sin(position * div_term)
        pe[:, 1::2] = np.cos(position * div_term)
        return pe
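# The positional-encoding table built above can be checked against its
# closed form, PE[pos, 2i] = sin(pos / 10000**(2i / d_model)) and
# PE[pos, 2i+1] = cos(pos / 10000**(2i / d_model)). Illustrative sketch;
# max_len and d_model are arbitrary here.
max_len_pe, d_model_pe = 16, 8
pe_chk = np.zeros((max_len_pe, d_model_pe))
position_pe = np.arange(max_len_pe)[:, None]
div_term_pe = np.exp(
    np.arange(0, d_model_pe, 2) * -(math.log(10000.0) / d_model_pe))
pe_chk[:, 0::2] = np.sin(position_pe * div_term_pe)
pe_chk[:, 1::2] = np.cos(position_pe * div_term_pe)

pos_i, dim_i = 5, 2
assert math.isclose(pe_chk[pos_i, 2 * dim_i],
                    math.sin(pos_i / 10000 ** (2 * dim_i / d_model_pe)),
                    rel_tol=1e-9)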
class MultiHeadAttention(Model):
    """Multi-head attention: project, split into heads, attend, and merge."""

    def __init__(self, n_heads, dropout):
        self.dropout = dropout
        self.n_heads = n_heads
        self.pwq = Parameter()
        self.pwk = Parameter()
        self.pwv = Parameter()
        self.pwo = Parameter()
        self.attention = ScaledDotProductAttention(dropout)
        self.scan_attributes()

    def init(self, d_model):
        assert d_model % self.n_heads == 0, 'd_model must be a multiple of n_heads.'
        self.pwq.init([d_model, d_model], I.XavierUniform())
        self.pwk.init([d_model, d_model], I.XavierUniform())
        self.pwv.init([d_model, d_model], I.XavierUniform())
        self.pwo.init([d_model, d_model], I.XavierUniform())

    def split_heads(self, x):
        d_model = x.shape()[1]
        d_k = d_model // self.n_heads
        ret = [
            self.F.slice(x, 1, i * d_k, (i + 1) * d_k)
            for i in range(self.n_heads)
        ]  # [n_heads, x_len, d_k]
        return ret

    @function_type
    def __call__(self, query, key, value, mask, train):
        wq = self.F.parameter(self.pwq)
        wk = self.F.parameter(self.pwk)
        wv = self.F.parameter(self.pwv)
        wo = self.F.parameter(self.pwo)
        query = query @ wq
        key = key @ wk
        value = value @ wv
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        if mask is not None:
            # Scale the binary mask by a large constant so that masked
            # positions are effectively removed inside the attention softmax.
            mask = 2000 * self.F.input(mask)
            if mask.shape()[0] != query[0].shape()[0]:
                mask = self.F.broadcast(mask, 0, query[0].shape()[0])
        heads = []
        for q, k, v in zip(query, key, value):
            head = self.attention(q, k, v, mask, train)
            heads.append(head)
        heads = self.F.concat(heads, 1)
        return heads @ wo
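# `ScaledDotProductAttention` is defined elsewhere; per head it presumably
# computes softmax(q k^T / sqrt(d_k) - mask) v, which matches the
# `2000 * mask` convention above (mask subtracted from the logits). A
# minimal NumPy sketch of that formula, under those assumptions:
def scaled_dot_product_attention_ref(q, k, v, mask=None):
    d_k = q.shape[-1]
    logits = q @ k.T / np.sqrt(d_k)                  # [q_len, k_len]
    if mask is not None:
        logits -= mask                               # large value ~ -inf
    w = np.exp(logits - logits.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)               # softmax over keys
    return w @ v                                     # [q_len, d_k]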
class EncoderDecoder(Model):
    """Standard encoder-decoder model."""

    def __init__(self):
        self.dropout_rate = DROPOUT_RATE
        self.psrc_lookup = Parameter()
        self.ptrg_lookup = Parameter()
        self.pwhy = Parameter()
        self.pby = Parameter()
        self.src_lstm = LSTM()
        self.trg_lstm = LSTM()
        self.scan_attributes()

    def init(self, src_vocab_size, trg_vocab_size, embed_size, hidden_size):
        """Creates a new EncoderDecoder object."""
        self.psrc_lookup.init([embed_size, src_vocab_size], I.XavierUniform())
        self.ptrg_lookup.init([embed_size, trg_vocab_size], I.XavierUniform())
        self.pwhy.init([trg_vocab_size, hidden_size], I.XavierUniform())
        self.pby.init([trg_vocab_size], I.Constant(0))
        self.src_lstm.init(embed_size, hidden_size)
        self.trg_lstm.init(embed_size, hidden_size)

    def encode(self, src_batch, train):
        """Encodes source sentences and prepares internal states."""
        # Reversed encoding.
        src_lookup = F.parameter(self.psrc_lookup)
        self.src_lstm.restart()
        for it in reversed(src_batch):
            x = F.pick(src_lookup, it, 1)
            x = F.dropout(x, self.dropout_rate, train)
            self.src_lstm.forward(x)

        # Initializes decoder states.
        self.trg_lookup = F.parameter(self.ptrg_lookup)
        self.why = F.parameter(self.pwhy)
        self.by = F.parameter(self.pby)
        self.trg_lstm.restart(self.src_lstm.get_c(), self.src_lstm.get_h())

    def decode_step(self, trg_words, train):
        """One step decoding."""
        x = F.pick(self.trg_lookup, trg_words, 1)
        x = F.dropout(x, self.dropout_rate, train)
        h = self.trg_lstm.forward(x)
        h = F.dropout(h, self.dropout_rate, train)
        return self.why @ h + self.by

    def loss(self, trg_batch, train):
        """Calculates loss values."""
        losses = []
        for i in range(len(trg_batch) - 1):
            y = self.decode_step(trg_batch[i], train)
            losses.append(F.softmax_cross_entropy(y, trg_batch[i + 1], 0))
        return F.batch.mean(F.sum(losses))
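# A minimal training-loop sketch for the model above, assuming primitiv's
# usual Graph/Device/Optimizer setup and that DROPOUT_RATE is defined by the
# surrounding script. The vocabulary sizes, layer sizes, and toy batch are
# illustrative assumptions only; real data would come from a corpus reader.
from primitiv import Device, Graph
from primitiv import devices as D
from primitiv import optimizers as O

dev = D.Naive()
Device.set_default(dev)
g = Graph()
Graph.set_default(g)

encdec = EncoderDecoder()
encdec.init(100, 100, 64, 64)      # src/trg vocab, embed, hidden (toy sizes)
optimizer = O.Adam()
optimizer.add(encdec)              # registers all parameters in the model

# Toy batch: one sentence, one word id per position (position-major lists).
src_batch = [[1], [2], [3], [0]]
trg_batch = [[1], [4], [5], [0]]

for epoch in range(10):
    g.clear()
    encdec.encode(src_batch, True)
    loss = encdec.loss(trg_batch, True)
    optimizer.reset_gradients()
    loss.backward()
    optimizer.update()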
class LSTM(Model):
    """LSTM cell."""

    def __init__(self):
        self._pwxh = Parameter()
        self._pwhh = Parameter()
        self._pbh = Parameter()
        self.scan_attributes()

    def init(self, in_size, out_size):
        """Creates a new LSTM."""
        self._pwxh.init([4 * out_size, in_size], I.XavierUniform())
        self._pwhh.init([4 * out_size, out_size], I.XavierUniform())
        self._pbh.init([4 * out_size], I.Constant(0))

    def reset(self, init_c=Node(), init_h=Node()):
        """Initializes internal states."""
        out_size = self._pwhh.shape()[1]
        self._wxh = F.parameter(self._pwxh)
        self._whh = F.parameter(self._pwhh)
        self._bh = F.parameter(self._pbh)
        self._c = init_c if init_c.valid() else F.zeros([out_size])
        self._h = init_h if init_h.valid() else F.zeros([out_size])

    def forward(self, x):
        """One step forwarding."""
        out_size = self._pwhh.shape()[1]
        # All four gate pre-activations are computed at once and then sliced
        # apart: u = [i; f; o; j].
        u = self._wxh @ x + self._whh @ self._h + self._bh
        i = F.sigmoid(F.slice(u, 0, 0, out_size))
        f = F.sigmoid(F.slice(u, 0, out_size, 2 * out_size))
        o = F.sigmoid(F.slice(u, 0, 2 * out_size, 3 * out_size))
        j = F.tanh(F.slice(u, 0, 3 * out_size, 4 * out_size))
        self._c = i * j + f * self._c
        self._h = o * F.tanh(self._c)
        return self._h

    def get_c(self):
        """Retrieves current internal cell state."""
        return self._c

    def get_h(self):
        """Retrieves current hidden value."""
        return self._h
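# One step of the same cell in plain NumPy, mirroring the [i; f; o; j]
# packing used above (an illustrative sketch; sizes and weights are
# arbitrary and not tied to the model).
def lstm_step_ref(wxh, whh, bh, x, c, h):
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    out_size = whh.shape[1]
    u = wxh @ x + whh @ h + bh
    i = sigmoid(u[0:out_size])                       # input gate
    f = sigmoid(u[out_size:2 * out_size])            # forget gate
    o = sigmoid(u[2 * out_size:3 * out_size])        # output gate
    j = np.tanh(u[3 * out_size:4 * out_size])        # candidate cell
    c_new = i * j + f * c
    h_new = o * np.tanh(c_new)
    return c_new, h_new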
class EncoderDecoder(Model):
    def __init__(self, dropout_rate):
        self.dropout_rate_ = dropout_rate
        self.psrc_lookup_ = Parameter()
        self.ptrg_lookup_ = Parameter()
        self.pwfbw_ = Parameter()
        self.pwhw_ = Parameter()
        self.pwwe_ = Parameter()
        self.pwhj_ = Parameter()
        self.pbj_ = Parameter()
        self.pwjy_ = Parameter()
        self.pby_ = Parameter()
        self.src_fw_lstm_ = LSTM()
        self.src_bw_lstm_ = LSTM()
        self.trg_lstm_ = LSTM()
        self.scan_attributes()

    def init(self, src_vocab_size, trg_vocab_size, embed_size, hidden_size):
        self.psrc_lookup_.init([embed_size, src_vocab_size], I.XavierUniform())
        self.ptrg_lookup_.init([embed_size, trg_vocab_size], I.XavierUniform())
        self.pwfbw_.init([2 * hidden_size, hidden_size], I.XavierUniform())
        self.pwhw_.init([hidden_size, hidden_size], I.XavierUniform())
        self.pwwe_.init([hidden_size], I.XavierUniform())
        self.pwhj_.init([embed_size, hidden_size], I.XavierUniform())
        self.pbj_.init([embed_size], I.Constant(0))
        self.pwjy_.init([trg_vocab_size, embed_size], I.XavierUniform())
        self.pby_.init([trg_vocab_size], I.Constant(0))
        self.src_fw_lstm_.init(embed_size, hidden_size)
        self.src_bw_lstm_.init(embed_size, hidden_size)
        self.trg_lstm_.init(embed_size + 2 * hidden_size, hidden_size)

    def encode(self, src_batch, train):
        # Embedding lookup.
        src_lookup = F.parameter(self.psrc_lookup_)
        e_list = []
        for x in src_batch:
            e = F.pick(src_lookup, x, 1)
            e = F.dropout(e, self.dropout_rate_, train)
            e_list.append(e)

        # Forward encoding.
        self.src_fw_lstm_.reset()
        f_list = []
        for e in e_list:
            f = self.src_fw_lstm_.forward(e)
            f = F.dropout(f, self.dropout_rate_, train)
            f_list.append(f)

        # Backward encoding.
        self.src_bw_lstm_.reset()
        b_list = []
        for e in reversed(e_list):
            b = self.src_bw_lstm_.forward(e)
            b = F.dropout(b, self.dropout_rate_, train)
            b_list.append(b)
        b_list.reverse()

        # Concatenates RNN states.
        fb_list = [F.concat([f_list[i], b_list[i]], 0)
                   for i in range(len(src_batch))]
        self.concat_fb = F.concat(fb_list, 1)
        self.t_concat_fb = F.transpose(self.concat_fb)

        # Initializes decoder states.
        self.wfbw_ = F.parameter(self.pwfbw_)
        self.whw_ = F.parameter(self.pwhw_)
        self.wwe_ = F.parameter(self.pwwe_)
        self.trg_lookup_ = F.parameter(self.ptrg_lookup_)
        self.whj_ = F.parameter(self.pwhj_)
        self.bj_ = F.parameter(self.pbj_)
        self.wjy_ = F.parameter(self.pwjy_)
        self.by_ = F.parameter(self.pby_)
        self.trg_lstm_.reset()

    def decode_step(self, trg_words, train):
        """One step decoding."""
        sentence_len = self.concat_fb.shape()[1]
        b = self.whw_ @ self.trg_lstm_.get_h()
        b = F.reshape(b, Shape([1, b.shape()[0]]))
        b = F.broadcast(b, 0, sentence_len)
        x = F.tanh(self.t_concat_fb @ self.wfbw_ + b)
        atten_prob = F.softmax(x @ self.wwe_, 0)
        c = self.concat_fb @ atten_prob
        e = F.pick(self.trg_lookup_, trg_words, 1)
        e = F.dropout(e, self.dropout_rate_, train)
        h = self.trg_lstm_.forward(F.concat([e, c], 0))
        h = F.dropout(h, self.dropout_rate_, train)
        j = F.tanh(self.whj_ @ h + self.bj_)
        return self.wjy_ @ j + self.by_

    def loss(self, trg_batch, train):
        """Calculates the loss function over given target sentences."""
        losses = []
        for i in range(len(trg_batch) - 1):
            y = self.decode_step(trg_batch[i], train)
            loss = F.softmax_cross_entropy(y, trg_batch[i + 1], 0)
            losses.append(loss)
        return F.batch.mean(F.sum(losses))
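# The attention score in `decode_step` above is, written out,
# probs = softmax(tanh(FB.T @ Wfbw + broadcast(Whw @ h)) @ wwe), where FB
# stacks the 2*hidden_size encoder states column-wise. A NumPy sketch with
# assumed sizes (illustrative only):
hidden_a, sent_len_a = 3, 5
rng_a = np.random.default_rng(0)
fb_a = rng_a.standard_normal((2 * hidden_a, sent_len_a))  # encoder states
wfbw_a = rng_a.standard_normal((2 * hidden_a, hidden_a))
whw_a = rng_a.standard_normal((hidden_a, hidden_a))
wwe_a = rng_a.standard_normal(hidden_a)
h_a = rng_a.standard_normal(hidden_a)                     # decoder state

scores_a = np.tanh(fb_a.T @ wfbw_a + (whw_a @ h_a)[None, :]) @ wwe_a
probs_a = np.exp(scores_a - scores_a.max())
probs_a /= probs_a.sum()                 # attention weights over positions
context_a = fb_a @ probs_a               # context vector, [2*hidden]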
class AttentionalEncoderDecoder(Model):
    """Encoder-decoder translation model with dot-attention."""

    def __init__(self):
        self.dropout_rate = DROPOUT_RATE
        self.psrc_lookup = Parameter()
        self.ptrg_lookup = Parameter()
        self.pwhj = Parameter()
        self.pbj = Parameter()
        self.pwjy = Parameter()
        self.pby = Parameter()
        self.src_fw_lstm = LSTM()
        self.src_bw_lstm = LSTM()
        self.trg_lstm = LSTM()
        self.add_all_parameters()
        self.add_all_submodels()

    def init(self, src_vocab_size, trg_vocab_size, embed_size, hidden_size):
        """Creates a new AttentionalEncoderDecoder object."""
        self.psrc_lookup.init([embed_size, src_vocab_size], I.XavierUniform())
        self.ptrg_lookup.init([embed_size, trg_vocab_size], I.XavierUniform())
        self.pwhj.init([embed_size, 2 * hidden_size], I.XavierUniform())
        self.pbj.init([embed_size], I.Constant(0))
        self.pwjy.init([trg_vocab_size, embed_size], I.XavierUniform())
        self.pby.init([trg_vocab_size], I.Constant(0))
        self.src_fw_lstm.init(embed_size, hidden_size)
        self.src_bw_lstm.init(embed_size, hidden_size)
        self.trg_lstm.init(2 * embed_size, hidden_size)

    def encode(self, src_batch, train):
        """Encodes source sentences and prepares internal states."""
        # Embedding lookup.
        src_lookup = F.parameter(self.psrc_lookup)
        e_list = []
        for x in src_batch:
            e = F.pick(src_lookup, x, 1)
            e = F.dropout(e, self.dropout_rate, train)
            e_list.append(e)

        # Forward encoding.
        self.src_fw_lstm.restart()
        f_list = []
        for e in e_list:
            f = self.src_fw_lstm.forward(e)
            f = F.dropout(f, self.dropout_rate, train)
            f_list.append(f)

        # Backward encoding.
        self.src_bw_lstm.restart()
        b_list = []
        for e in reversed(e_list):
            b = self.src_bw_lstm.forward(e)
            b = F.dropout(b, self.dropout_rate, train)
            b_list.append(b)
        b_list.reverse()

        # Sums forward/backward states at each position and lines them up
        # along the sequence axis.
        fb_list = [f_list[i] + b_list[i] for i in range(len(src_batch))]
        self.concat_fb = F.concat(fb_list, 1)
        self.t_concat_fb = F.transpose(self.concat_fb)

        # Initializes decoder states.
        embed_size = self.psrc_lookup.shape()[0]
        self.trg_lookup = F.parameter(self.ptrg_lookup)
        self.whj = F.parameter(self.pwhj)
        self.bj = F.parameter(self.pbj)
        self.wjy = F.parameter(self.pwjy)
        self.by = F.parameter(self.pby)
        self.feed = F.zeros([embed_size])
        self.trg_lstm.restart(
            self.src_fw_lstm.get_c() + self.src_bw_lstm.get_c(),
            self.src_fw_lstm.get_h() + self.src_bw_lstm.get_h())

    def decode_step(self, trg_words, train):
        """One step decoding."""
        e = F.pick(self.trg_lookup, trg_words, 1)
        e = F.dropout(e, self.dropout_rate, train)
        # Input feeding: the previous attention output is fed back in.
        h = self.trg_lstm.forward(F.concat([e, self.feed], 0))
        h = F.dropout(h, self.dropout_rate, train)
        atten_probs = F.softmax(self.t_concat_fb @ h, 0)
        c = self.concat_fb @ atten_probs
        self.feed = F.tanh(self.whj @ F.concat([h, c], 0) + self.bj)
        return self.wjy @ self.feed + self.by

    def loss(self, trg_batch, train):
        """Calculates loss values."""
        losses = []
        for i in range(len(trg_batch) - 1):
            y = self.decode_step(trg_batch[i], train)
            loss = F.softmax_cross_entropy(y, trg_batch[i + 1], 0)
            losses.append(loss)
        return F.batch.mean(F.sum(losses))
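# A greedy decoding sketch using the model above. The BOS/EOS ids, the
# per-call graph setup, and `max_len` are assumptions for illustration,
# not part of the original listing; a batch size of one is assumed so that
# `to_list()` returns a single distribution over the target vocabulary.
def translate_greedy(encdec, src_batch, bos_id, eos_id, max_len=64):
    g = Graph()
    Graph.set_default(g)
    encdec.encode(src_batch, False)
    word = bos_id
    result = []
    for _ in range(max_len):
        y = encdec.decode_step([word], False)
        word = int(np.argmax(y.to_list()))     # pick the best next word
        if word == eos_id:
            break
        result.append(word)
    return result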