def model_graph(features, labels, mode, params):
    hidden_size = params.hidden_size

    src_seq = features["source"]
    tgt_seq = features["target"]
    src_len = features["source_length"]
    tgt_len = features["target_length"]
    src_mask = tf.sequence_mask(src_len,
                                maxlen=tf.shape(features["source"])[1],
                                dtype=tf.float32)
    tgt_mask = tf.sequence_mask(tgt_len,
                                maxlen=tf.shape(features["target"])[1],
                                dtype=tf.float32)

    src_embedding, tgt_embedding, weights = get_weights(params)
    bias = tf.get_variable("bias", [hidden_size])

    # id => embedding
    # src_seq: [batch, max_src_length]
    # tgt_seq: [batch, max_tgt_length]
    inputs = tf.gather(src_embedding, src_seq) * (hidden_size ** 0.5)
    targets = tf.gather(tgt_embedding, tgt_seq) * (hidden_size ** 0.5)
    inputs = inputs * tf.expand_dims(src_mask, -1)
    targets = targets * tf.expand_dims(tgt_mask, -1)

    # Preparing encoder & decoder input
    encoder_input = tf.nn.bias_add(inputs, bias)
    encoder_input = layers.attention.add_timing_signal(encoder_input)
    enc_attn_bias = layers.attention.attention_bias(src_mask, "masking")
    dec_attn_bias = layers.attention.attention_bias(tf.shape(targets)[1],
                                                    "causal")
    # Shift left
    decoder_input = tf.pad(targets, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    decoder_input = layers.attention.add_timing_signal(decoder_input)

    if params.residual_dropout:
        keep_prob = 1.0 - params.residual_dropout
        encoder_input = tf.nn.dropout(encoder_input, keep_prob)
        decoder_input = tf.nn.dropout(decoder_input, keep_prob)

    encoder_output = transformer_encoder(encoder_input, enc_attn_bias, params)
    decoder_output = transformer_decoder(decoder_input, encoder_output,
                                         dec_attn_bias, enc_attn_bias, params)

    # inference mode, take the last position
    if mode == "infer":
        decoder_output = decoder_output[:, -1, :]
        logits = tf.matmul(decoder_output, weights, False, True)

        return logits

    # [batch, length, channel] => [batch * length, vocab_size]
    decoder_output = tf.reshape(decoder_output, [-1, hidden_size])
    logits = tf.matmul(decoder_output, weights, False, True)

    # label smoothing
    ce = layers.nn.smoothed_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=labels,
        smoothing=params.label_smoothing,
        normalize=True
    )

    ce = tf.reshape(ce, tf.shape(tgt_seq))

    loss = get_loss(features, params, ce, tgt_mask)

    return loss
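# A minimal standalone sketch (not part of the model above) illustrating the
# "Shift left" preparation of the decoder input: padding one zero step at the
# front and dropping the last step means the decoder at position t is fed the
# target embedding of position t - 1 (teacher forcing). TensorFlow 1.x, as in
# the surrounding code; the tensor values are made up for illustration only.
import numpy as np
import tensorflow as tf

targets = tf.constant(np.arange(1., 7., dtype=np.float32).reshape(1, 3, 2))
shifted = tf.pad(targets, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]

with tf.Session() as sess:
    print(sess.run(targets))  # [[[1. 2.] [3. 4.] [5. 6.]]]
    print(sess.run(shifted))  # [[[0. 0.] [1. 2.] [3. 4.]]]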
def model_graph(features, labels, params):
    src_vocab_size = len(params.vocabulary["source"])
    tgt_vocab_size = len(params.vocabulary["target"])

    src_seq = features["source"]
    tgt_seq = features["target"]

    if params.reverse_source:
        src_seq = tf.reverse_sequence(src_seq, seq_dim=1,
                                      seq_lengths=features["source_length"])

    with tf.device("/cpu:0"):
        with tf.variable_scope("source_embedding"):
            src_emb = tf.get_variable("embedding",
                                      [src_vocab_size, params.embedding_size])
            src_bias = tf.get_variable("bias", [params.embedding_size])
            src_inputs = tf.nn.embedding_lookup(src_emb, src_seq)

        with tf.variable_scope("target_embedding"):
            tgt_emb = tf.get_variable("embedding",
                                      [tgt_vocab_size, params.embedding_size])
            tgt_bias = tf.get_variable("bias", [params.embedding_size])
            tgt_inputs = tf.nn.embedding_lookup(tgt_emb, tgt_seq)

    src_inputs = tf.nn.bias_add(src_inputs, src_bias)
    tgt_inputs = tf.nn.bias_add(tgt_inputs, tgt_bias)

    if params.dropout and not params.use_variational_dropout:
        src_inputs = tf.nn.dropout(src_inputs, 1.0 - params.dropout)
        tgt_inputs = tf.nn.dropout(tgt_inputs, 1.0 - params.dropout)

    cell_enc = []
    cell_dec = []

    for _ in range(params.num_hidden_layers):
        if params.rnn_cell == "LSTMCell":
            cell_e = tf.nn.rnn_cell.BasicLSTMCell(params.hidden_size)
            cell_d = tf.nn.rnn_cell.BasicLSTMCell(params.hidden_size)
        elif params.rnn_cell == "GRUCell":
            cell_e = tf.nn.rnn_cell.GRUCell(params.hidden_size)
            cell_d = tf.nn.rnn_cell.GRUCell(params.hidden_size)
        else:
            raise ValueError("%s not supported" % params.rnn_cell)

        cell_e = tf.nn.rnn_cell.DropoutWrapper(
            cell_e,
            input_keep_prob=1.0 - params.dropout,
            output_keep_prob=1.0 - params.dropout,
            state_keep_prob=1.0 - params.dropout,
            variational_recurrent=params.use_variational_dropout,
            input_size=params.embedding_size,
            dtype=tf.float32
        )
        cell_d = tf.nn.rnn_cell.DropoutWrapper(
            cell_d,
            input_keep_prob=1.0 - params.dropout,
            output_keep_prob=1.0 - params.dropout,
            state_keep_prob=1.0 - params.dropout,
            variational_recurrent=params.use_variational_dropout,
            input_size=params.embedding_size,
            dtype=tf.float32
        )

        if params.use_residual:
            cell_e = tf.nn.rnn_cell.ResidualWrapper(cell_e)
            cell_d = tf.nn.rnn_cell.ResidualWrapper(cell_d)

        cell_enc.append(cell_e)
        cell_dec.append(cell_d)

    cell_enc = tf.nn.rnn_cell.MultiRNNCell(cell_enc)
    cell_dec = tf.nn.rnn_cell.MultiRNNCell(cell_dec)

    with tf.variable_scope("encoder"):
        _, final_state = tf.nn.dynamic_rnn(cell_enc, src_inputs,
                                           features["source_length"],
                                           dtype=tf.float32)

    # Shift left
    shifted_tgt_inputs = tf.pad(tgt_inputs, [[0, 0], [1, 0], [0, 0]])
    shifted_tgt_inputs = shifted_tgt_inputs[:, :-1, :]

    with tf.variable_scope("decoder"):
        outputs, _ = tf.nn.dynamic_rnn(cell_dec, shifted_tgt_inputs,
                                       features["target_length"],
                                       initial_state=final_state)

    if params.dropout:
        outputs = tf.nn.dropout(outputs, 1.0 - params.dropout)

    if labels is None:
        # Prediction
        logits = layers.nn.linear(outputs[:, -1, :], tgt_vocab_size, True,
                                  scope="softmax")
        return logits

    # Prediction
    logits = layers.nn.linear(outputs, tgt_vocab_size, True, scope="softmax")
    logits = tf.reshape(logits, [-1, tgt_vocab_size])

    ce = layers.nn.smoothed_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=labels,
        smoothing=params.label_smoothing,
        normalize=True
    )

    ce = tf.reshape(ce, tf.shape(labels))
    tgt_mask = tf.to_float(
        tf.sequence_mask(features["target_length"],
                         maxlen=tf.shape(features["target"])[1]))

    loss = get_loss(features, params, ce, tgt_mask)

    return loss
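# Hedged sketch of the label smoothing that smoothed_softmax_cross_entropy_with_logits
# is assumed to perform above: the one-hot target is replaced by a distribution putting
# (1 - smoothing) on the gold token and spreading the remaining mass over the other
# classes. The actual helper may differ in details (e.g. the extra normalization term
# implied by normalize=True); this function and its name are illustrative only.
import tensorflow as tf

def smoothed_ce_sketch(logits, labels, smoothing):
    # logits: [batch * length, vocab_size], labels: [batch * length] int ids
    vocab_size = tf.shape(logits)[-1]
    on_value = 1.0 - smoothing
    off_value = smoothing / tf.to_float(vocab_size - 1)
    soft_targets = tf.one_hot(labels, depth=vocab_size,
                              on_value=on_value, off_value=off_value)
    return tf.nn.softmax_cross_entropy_with_logits_v2(labels=soft_targets,
                                                      logits=logits)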
def model_graph(features, labels, params):
    src_vocab_size = len(params.vocabulary["source"])
    tgt_vocab_size = len(params.vocabulary["target"])

    with tf.variable_scope("source_embedding"):
        src_emb = tf.get_variable("embedding",
                                  [src_vocab_size, params.embedding_size])
        src_bias = tf.get_variable("bias", [params.embedding_size])
        src_inputs = tf.nn.embedding_lookup(src_emb, features["source"])

    with tf.variable_scope("target_embedding"):
        tgt_emb = tf.get_variable("embedding",
                                  [tgt_vocab_size, params.embedding_size])
        tgt_bias = tf.get_variable("bias", [params.embedding_size])
        tgt_inputs = tf.nn.embedding_lookup(tgt_emb, features["target"])

    src_inputs = tf.nn.bias_add(src_inputs, src_bias)
    tgt_inputs = tf.nn.bias_add(tgt_inputs, tgt_bias)

    if params.dropout and not params.use_variational_dropout:
        src_inputs = tf.nn.dropout(src_inputs, 1.0 - params.dropout)
        tgt_inputs = tf.nn.dropout(tgt_inputs, 1.0 - params.dropout)

    # encoder
    cell_fw = layers.rnn_cell.LegacyGRUCell(params.hidden_size)
    cell_bw = layers.rnn_cell.LegacyGRUCell(params.hidden_size)

    if params.use_variational_dropout:
        cell_fw = tf.nn.rnn_cell.DropoutWrapper(
            cell_fw,
            input_keep_prob=1.0 - params.dropout,
            output_keep_prob=1.0 - params.dropout,
            state_keep_prob=1.0 - params.dropout,
            variational_recurrent=True,
            input_size=params.embedding_size,
            dtype=tf.float32
        )
        cell_bw = tf.nn.rnn_cell.DropoutWrapper(
            cell_bw,
            input_keep_prob=1.0 - params.dropout,
            output_keep_prob=1.0 - params.dropout,
            state_keep_prob=1.0 - params.dropout,
            variational_recurrent=True,
            input_size=params.embedding_size,
            dtype=tf.float32
        )

    encoder_output = _encoder(cell_fw, cell_bw, src_inputs,
                              features["source_length"])

    # decoder
    cell = layers.rnn_cell.LegacyGRUCell(params.hidden_size)

    if params.use_variational_dropout:
        cell = tf.nn.rnn_cell.DropoutWrapper(
            cell,
            input_keep_prob=1.0 - params.dropout,
            output_keep_prob=1.0 - params.dropout,
            state_keep_prob=1.0 - params.dropout,
            variational_recurrent=True,
            # input + context
            input_size=params.embedding_size + 2 * params.hidden_size,
            dtype=tf.float32
        )

    length = {
        "source": features["source_length"],
        "target": features["target_length"]
    }
    initial_state = encoder_output["final_states"]["backward"]
    decoder_output = _decoder(cell, tgt_inputs, encoder_output["annotation"],
                              length, initial_state)

    # Shift left
    shifted_tgt_inputs = tf.pad(tgt_inputs, [[0, 0], [1, 0], [0, 0]])
    shifted_tgt_inputs = shifted_tgt_inputs[:, :-1, :]

    all_outputs = tf.concat(
        [
            tf.expand_dims(decoder_output["initial_state"], axis=1),
            decoder_output["outputs"],
        ],
        axis=1
    )
    shifted_outputs = all_outputs[:, :-1, :]

    maxout_features = [
        shifted_tgt_inputs,
        shifted_outputs,
        decoder_output["values"]
    ]
    maxout_size = params.hidden_size // params.maxnum

    if labels is None:
        # Special case for non-incremental decoding
        maxout_features = [
            shifted_tgt_inputs[:, -1, :],
            shifted_outputs[:, -1, :],
            decoder_output["values"][:, -1, :]
        ]
        maxhid = layers.nn.maxout(maxout_features, maxout_size,
                                  params.maxnum, concat=False)
        readout = layers.nn.linear(maxhid, params.embedding_size, False,
                                   False, scope="deepout")

        # Prediction
        logits = layers.nn.linear(readout, tgt_vocab_size, True, False,
                                  scope="softmax")

        return logits

    maxhid = layers.nn.maxout(maxout_features, maxout_size, params.maxnum,
                              concat=False)
    readout = layers.nn.linear(maxhid, params.embedding_size, False, False,
                               scope="deepout")

    if params.dropout and not params.use_variational_dropout:
        readout = tf.nn.dropout(readout, 1.0 - params.dropout)

    # Prediction
    logits = layers.nn.linear(readout, tgt_vocab_size, True, False,
                              scope="softmax")
    logits = tf.reshape(logits, [-1, tgt_vocab_size])

    ce = layers.nn.smoothed_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=labels,
        smoothing=params.label_smoothing,
        normalize=True
    )

    ce = tf.reshape(ce, tf.shape(labels))
    tgt_mask = tf.to_float(
        tf.sequence_mask(features["target_length"],
                         maxlen=tf.shape(features["target"])[1]))

    # loss = tf.reduce_sum(ce * tgt_mask) / tf.reduce_sum(tgt_mask)
    loss = get_loss(features, params, ce, tgt_mask)

    return loss
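# Hedged sketch of the masked loss hinted at by the commented-out line above:
# a mean of the per-token cross entropy over the non-padding positions. The real
# get_loss helper may add further options (e.g. sentence-level normalization or
# loss scaling); this function and its name are illustrative only.
import tensorflow as tf

def masked_token_loss_sketch(ce, tgt_mask):
    # ce, tgt_mask: [batch, max_target_length]; tgt_mask is 1.0 on real tokens
    # and 0.0 on padding, so padded positions contribute nothing to the loss.
    return tf.reduce_sum(ce * tgt_mask) / tf.reduce_sum(tgt_mask)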