# Imports reconstructed for this excerpt; module paths assume TF 1.x and the
# Icecaps source layout.
import math

import numpy as np
import tensorflow as tf
from tensorflow.contrib import seq2seq
from tensorflow.contrib.rnn import (BasicRNNCell, GRUCell, LSTMCell,
                                    DropoutWrapper, MultiRNNCell)
from tensorflow.python.layers.core import Dense

from icecaps.estimators.abstract_icecaps_estimator import AbstractIcecapsEstimator
from icecaps.util.vocabulary import Vocabulary


class AbstractRecurrentEstimator(AbstractIcecapsEstimator):

    @classmethod
    def construct_expected_params(cls):
        expected_params = super().construct_expected_params()
        expected_params["max_length"] = cls.make_param(50)
        expected_params["cell_type"] = cls.make_param('gru')
        expected_params["hidden_units"] = cls.make_param(32)
        expected_params["depth"] = cls.make_param(1)
        expected_params["token_embed_dim"] = cls.make_param(16)
        expected_params["tie_token_embeddings"] = cls.make_param(True)
        expected_params["beam_width"] = cls.make_param(8)
        expected_params["vocab_file"] = cls.make_param("./dummy_data/vocab.dic")
        expected_params["vocab_size"] = cls.make_param(0)
        expected_params["skip_tokens"] = cls.make_param('')
        expected_params["skip_tokens_start"] = cls.make_param('')
        return expected_params

    def extract_args(self, features, mode, params):
        super().extract_args(features, mode, params)
        if self.hparams.vocab_size > 0:
            self.vocab = Vocabulary(size=self.hparams.vocab_size)
        else:
            self.vocab = Vocabulary(
                fname=self.hparams.vocab_file,
                skip_tokens=self.hparams.skip_tokens,
                skip_tokens_start=self.hparams.skip_tokens_start)

    def build_cell(self, name=None):
        if self.hparams.cell_type == 'linear':
            cell = BasicRNNCell(self.hparams.hidden_units,
                                activation=tf.identity, name=name)
        elif self.hparams.cell_type == 'tanh':
            cell = BasicRNNCell(self.hparams.hidden_units,
                                activation=tf.tanh, name=name)
        elif self.hparams.cell_type == 'relu':
            cell = BasicRNNCell(self.hparams.hidden_units,
                                activation=tf.nn.relu, name=name)
        elif self.hparams.cell_type == 'gru':
            cell = GRUCell(self.hparams.hidden_units, name=name)
        elif self.hparams.cell_type == 'lstm':
            cell = LSTMCell(self.hparams.hidden_units, name=name)
        else:
            raise ValueError('Provided cell type not supported.')
        return cell

    def build_deep_cell(self, cell_list=None, name=None, return_raw_list=False):
        if name is None:
            name = "cell"
        if cell_list is None:
            cell_list = []
        for i in range(self.hparams.depth):
            cell = self.build_cell(name=name + "_" + str(i))
            cell = DropoutWrapper(cell, output_keep_prob=self.keep_prob)
            cell_list.append(cell)
        if return_raw_list:
            return cell_list
        if len(cell_list) == 1:
            return cell_list[0]
        return MultiRNNCell(cell_list)

    def build_rnn(self, input_key="inputs"):
        with tf.variable_scope('rnn'):
            self.cell = self.build_deep_cell()
            self.build_inputs(input_key)
            # outputs: [batch_size, max_time_step, cell_output_size]
            # last_state: [batch_size, cell_output_size]
            self.outputs, self.last_state = tf.nn.dynamic_rnn(
                cell=self.cell,
                inputs=self.inputs_dense,
                sequence_length=self.inputs_length,
                time_major=False,
                dtype=tf.float32)

    def build_embeddings(self):
        if "token_embeddings" in self.features and self.hparams.tie_token_embeddings:
            self.token_embeddings = self.features["token_embeddings"]
        else:
            self.token_embeddings = tf.get_variable(
                name='token_embeddings',
                shape=[self.vocab.size(), self.hparams.token_embed_dim])
        if self.hparams.token_embed_dim != self.hparams.hidden_units:
            # Project token embeddings to the cell size when the dims differ.
            projection = tf.get_variable(
                name='token_embed_proj',
                shape=[self.hparams.token_embed_dim, self.hparams.hidden_units])
            self.token_embeddings = self.token_embeddings @ projection

    def embed_sparse_to_dense(self, sparse):
        with tf.variable_scope('embed_sparse_to_dense', reuse=tf.AUTO_REUSE):
            dense = tf.nn.embedding_lookup(self.token_embeddings, sparse)
        return dense

    def build_inputs(self, input_key):
        self.build_embeddings()
        self.inputs_sparse_untrimmed = tf.cast(self.features[input_key], tf.int32)
        # Positions holding the end token contribute zero below, so
        # count_nonzero recovers the unpadded sequence lengths.
        self.inputs_length = tf.cast(
            tf.count_nonzero(
                self.inputs_sparse_untrimmed - self.vocab.end_token_id, -1),
            tf.int32)
        self.inputs_max_length = tf.reduce_max(self.inputs_length)
        self.inputs_sparse = tf.slice(self.inputs_sparse_untrimmed, [0, 0],
                                      [-1, self.inputs_max_length])
        self.inputs_dense = self.embed_sparse_to_dense(self.inputs_sparse)
        self.batch_size = tf.shape(self.inputs_sparse)[0]

    def build_loss(self):
        with tf.name_scope('build_loss'):
            self.loss = seq2seq.sequence_loss(
                logits=self.logits,
                targets=self.targets_sparse,
                weights=self.target_mask,
                average_across_timesteps=True,
                average_across_batch=True)
            self.reported_loss = tf.identity(self.loss, 'reported_loss')
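
# A minimal, self-contained sketch (assuming a TF 1.x runtime) of the length
# computation used in build_inputs above: positions holding the end token id
# contribute zero to (inputs - end_token_id), so count_nonzero over the last
# axis recovers the unpadded lengths. Note this relies on the end token only
# appearing as padding; the ids below are hypothetical demo values.
def _demo_length_from_padding():
    inputs = tf.constant([[4, 7, 2, 0, 0],
                          [9, 0, 0, 0, 0]], dtype=tf.int32)
    end_token_id = 0
    lengths = tf.cast(tf.count_nonzero(inputs - end_token_id, -1), tf.int32)
    with tf.Session() as sess:
        print(sess.run(lengths))  # -> [3 1]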
class AbstractTransformerEstimator(AbstractIcecapsEstimator):

    @classmethod
    def construct_expected_params(cls):
        expected_params = super().construct_expected_params()
        expected_params["vocab_file"] = cls.make_param(
            "icecaps/examples/dummy_data/vocab.dic")
        expected_params["vocab_size"] = cls.make_param(0)
        expected_params["depth"] = cls.make_param(1)
        expected_params["num_heads"] = cls.make_param(8)
        expected_params["d_model"] = cls.make_param(32)
        expected_params["d_pos"] = cls.make_param(32)
        expected_params["d_ff"] = cls.make_param(64)
        expected_params["max_length"] = cls.make_param(10)
        expected_params["min_wavelength"] = cls.make_param(1.0)
        expected_params["max_wavelength"] = cls.make_param(1000.0)
        expected_params["warmup_steps"] = cls.make_param(4000.0)
        expected_params["fixed_learning_rate"] = cls.make_param(False)
        expected_params["learn_wavelengths"] = cls.make_param(False)
        expected_params["modality"] = cls.make_param("seq")
        expected_params["tree_depth"] = cls.make_param(256)
        expected_params["tree_width"] = cls.make_param(2)
        expected_params["learn_positional_embeddings"] = cls.make_param(False)
        return expected_params

    def extract_args(self, features, mode, params):
        super().extract_args(features, mode, params)
        self.d_k = self.hparams.d_model // self.hparams.num_heads
        # The original guards here were tautologies (`x if x == 0 else x`);
        # presumably they were meant to fall back to d_model when the param
        # is left at 0, so that assumption is applied here.
        self.d_pos = self.hparams.d_pos if self.hparams.d_pos != 0 else self.hparams.d_model
        self.d_ff = self.hparams.d_ff if self.hparams.d_ff != 0 else self.hparams.d_model
        if self.hparams.vocab_size > 0:
            self.vocab = Vocabulary(size=self.hparams.vocab_size)
        else:
            self.vocab = Vocabulary(fname=self.hparams.vocab_file)
        if not self.hparams.fixed_learning_rate:
            # Note: the variable is float32, so the initializer must not
            # request int32 (the original passed dtype=tf.int32 here).
            self.train_step = tf.get_variable(
                'train_step', shape=[], dtype=tf.float32,
                initializer=tf.zeros_initializer(), trainable=False)
            # Warmup schedule from the Transformer paper: linear ramp for
            # warmup_steps, then ~step^-0.5 decay.
            self.learning_rate = (
                tf.sqrt(1.0 / self.hparams.d_model) *
                tf.minimum(
                    self.train_step * tf.pow(self.hparams.warmup_steps, -1.5),
                    tf.pow(self.train_step, -0.5)))

    def build_embeddings(self):
        # Sinusoidal positional embeddings, precomputed for positions up
        # to 2048 (rather than max_length).
        position = tf.expand_dims(
            tf.cast(tf.range(0, 2048), dtype=tf.float32), 1)
        if self.hparams.learn_wavelengths:
            wavelength_logs = tf.get_variable(
                "wavelength_logs", [self.d_pos // 2], tf.float32)
        else:
            wavelength_logs = tf.linspace(
                math.log(self.hparams.min_wavelength),
                math.log(self.hparams.max_wavelength),
                self.d_pos // 2)
        div_term = tf.expand_dims(tf.exp(-wavelength_logs), 0)
        outer_product = tf.matmul(position, div_term)
        cosines = tf.cos(outer_product)
        sines = tf.sin(outer_product)
        self.positional_embeddings = tf.concat([cosines, sines], -1)
        if self.hparams.learn_positional_embeddings:
            self.positional_embeddings = tf.get_variable(
                name='positional_embeddings',
                shape=[self.hparams.max_length, self.hparams.d_model]
            ) * np.sqrt(float(self.hparams.d_model))
        self.token_embeddings = tf.get_variable(
            name='token_embeddings',
            shape=[self.vocab.size(), self.hparams.d_model]
        ) * np.sqrt(float(self.hparams.d_model))
        if self.hparams.modality == "tree":
            self.d_tree_param = self.d_pos // (
                self.hparams.tree_depth * self.hparams.tree_width)
            self.tree_params = tf.tanh(
                tf.get_variable("tree_params", [self.d_tree_param]))
            self.tiled_tree_params = tf.tile(
                tf.reshape(self.tree_params, [1, 1, -1]),
                [self.hparams.tree_depth, self.hparams.tree_width, 1])
            self.tiled_depths = tf.tile(
                tf.reshape(
                    tf.range(self.hparams.tree_depth, dtype=tf.float32),
                    [-1, 1, 1]),
                [1, self.hparams.tree_width, self.d_tree_param])
            self.tree_norm = tf.sqrt(
                (1 - tf.square(self.tree_params)) * self.hparams.d_model / 2)
            self.tree_weights = tf.reshape(
                tf.pow(self.tiled_tree_params, self.tiled_depths) * self.tree_norm,
                [self.hparams.tree_depth * self.hparams.tree_width,
                 self.d_tree_param])

    def treeify_positions(self, positions):
        treeified = tf.expand_dims(positions, -1) * self.tree_weights
        shape = tf.shape(treeified)
        shape = tf.concat([shape[:-2], [self.d_pos]], -1)
        treeified = tf.reshape(treeified, shape)
        return treeified

    def init_inputs(self):
        self.inputs_sparse = tf.cast(self.features["inputs"], tf.int32)
        self.mask = tf.cast(
            tf.not_equal(self.inputs_sparse, self.vocab.end_token_id),
            tf.float32)
        self.inputs_length = tf.cast(tf.count_nonzero(self.mask, -1), tf.int32)
        self.inputs_max_length = tf.reduce_max(self.inputs_length)
        self.batch_size = tf.shape(self.inputs_sparse)[0]
        self.inputs_sparse = tf.slice(
            self.inputs_sparse, [0, 0],
            [self.batch_size, self.inputs_max_length])
        self.mask = tf.slice(self.mask, [0, 0],
                             [self.batch_size, self.inputs_max_length])
        self.inputs_dense = tf.nn.embedding_lookup(
            params=self.token_embeddings, ids=self.inputs_sparse)
        if self.hparams.modality == "seq":
            self.positions = tf.slice(self.positional_embeddings, [0, 0],
                                      [self.inputs_max_length, self.d_pos])
        elif self.hparams.modality == "tree":
            self.positions = tf.reshape(
                self.features["inputs_positions"],
                [self.batch_size, self.inputs_max_length,
                 self.hparams.tree_depth * self.hparams.tree_width])
            self.positions = self.treeify_positions(self.positions)
        else:
            raise ValueError("This input modality is not supported.")
        if self.d_pos != self.hparams.d_model:
            self.positions = tf.layers.dense(self.positions,
                                             self.hparams.d_model)
        self.inputs_dense = self.inputs_dense + self.positions
        self.inputs_dense = tf.nn.dropout(self.inputs_dense, self.keep_prob)
        # Zero out embeddings at padded positions.
        self.inputs_dense = tf.transpose(
            tf.transpose(self.inputs_dense) * tf.transpose(self.mask))

    def build_layer_norm(self, x):
        return tf.contrib.layers.layer_norm(x, begin_norm_axis=-1)

    def build_sublayer_fn(self, x, f):
        # Pre-norm residual sublayer: x + dropout(f(norm(x))).
        x = self.build_layer_norm(x)
        x = x + tf.nn.dropout(f(x), self.keep_prob)
        return x

    def attention(self, query, key, value, d_k, enc_mask=None, dec_mask=None):
        # Scaled dot-product attention; masked positions get a large
        # negative score so softmax ignores them.
        scores = tf.matmul(query, tf.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k)
        if enc_mask is not None:
            scores = tf.transpose(
                scores, [1, 2, 0, 3]) * enc_mask - 1e24 * (1.0 - enc_mask)
            scores = tf.transpose(scores, [2, 0, 1, 3])
        if dec_mask is not None:
            scores = scores * dec_mask - 1e24 * (1.0 - dec_mask)
        p_attn = tf.nn.softmax(scores)
        p_attn = tf.nn.dropout(p_attn, keep_prob=self.keep_prob)
        attended_values = tf.matmul(p_attn, value)
        return attended_values, p_attn

    def mha_fn(self, query, key, value, batch_size, enc_mask_, dec_mask_):
        with tf.variable_scope("mha", reuse=tf.AUTO_REUSE):
            # Project and split into heads: [batch, heads, time, d_k].
            query = tf.transpose(
                tf.reshape(
                    tf.layers.dense(query, self.hparams.d_model, use_bias=True),
                    [batch_size, -1, self.hparams.num_heads, self.d_k]),
                [0, 2, 1, 3])
            key = tf.transpose(
                tf.reshape(
                    tf.layers.dense(key, self.hparams.d_model, use_bias=True),
                    [batch_size, -1, self.hparams.num_heads, self.d_k]),
                [0, 2, 1, 3])
            value = tf.transpose(
                tf.reshape(
                    tf.layers.dense(value, self.hparams.d_model, use_bias=True),
                    [batch_size, -1, self.hparams.num_heads, self.d_k]),
                [0, 2, 1, 3])
            attended, _ = self.attention(query, key, value, self.d_k,
                                         enc_mask_, dec_mask_)
            # Merge heads back: [batch, time, d_model].
            attended = tf.reshape(tf.transpose(attended, [0, 2, 1, 3]),
                                  [batch_size, -1, self.hparams.d_model])
            return attended

    def build_mha_sublayer(self, x, m, batch_size, enc_mask=None, dec_mask=None):
        with tf.variable_scope("attn", reuse=tf.AUTO_REUSE):
            return self.build_sublayer_fn(
                x,
                lambda q: tf.layers.dense(
                    self.mha_fn(q, m, m, batch_size, enc_mask, dec_mask),
                    self.hparams.d_model))

    def build_ffn_sublayer(self, x, d_ff):
        with tf.variable_scope("ffn", reuse=tf.AUTO_REUSE):
            def ffn_fn(q):
                return tf.layers.dense(tf.layers.dense(q, d_ff, tf.nn.relu),
                                       self.hparams.d_model)
            return self.build_sublayer_fn(x, ffn_fn)

    def build_optimizer(self, trainable_params=None):
        super().build_optimizer(trainable_params)
        # Tick the step counter used by the warmup schedule on every
        # training step.
        self.step_update_op = tf.assign_add(self.train_step, 1.0)
        with tf.control_dependencies([self.step_update_op]):
            self.train_op = tf.group([self.step_update_op, self.train_op])
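
# A small numeric sketch of the warmup schedule built in extract_args above
# (the formula from "Attention Is All You Need"): the rate ramps up linearly
# for warmup_steps, then decays as step^-0.5, peaking at step == warmup_steps.
# Plain-Python restatement for illustration only; the defaults mirror the
# expected params above.
def _demo_transformer_lr(d_model=32, warmup_steps=4000.0):
    def lr(step):
        return math.sqrt(1.0 / d_model) * min(step * warmup_steps ** -1.5,
                                              step ** -0.5)
    for step in (1.0, 1000.0, 4000.0, 16000.0):
        print("step %6d -> lr %.6f" % (step, lr(step)))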
class RNNEstimator(AbstractRecurrentEstimator):

    def _model_fn(self, features, mode, params):
        with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
            self.extract_args(features, mode, params)
            self.init_inputs()
            self.build_cell()
            self.build_obj()
            if mode == tf.estimator.ModeKeys.PREDICT:
                self.build_rt_decoder()
                self.predictions = {
                    "inputs": self.features["inputs"],
                    "outputs": self.hypotheses,
                    "scores": self.scores,
                }
                if "metadata" in self.features:
                    self.predictions["metadata"] = self.features["metadata"]
                return tf.estimator.EstimatorSpec(
                    mode, predictions=self.predictions)
            self.init_targets()
            self.build_loss()
            if mode == tf.estimator.ModeKeys.TRAIN:
                self.build_optimizer()
                # Add histograms for trainable variables.
                for var in tf.trainable_variables():
                    tf.summary.histogram(var.op.name, var)
                return tf.estimator.EstimatorSpec(
                    mode, loss=self.reported_loss, train_op=self.train_op)
            if mode == tf.estimator.ModeKeys.EVAL:
                print("Number of parameters: " +
                      str(self.get_num_model_params()))
                self.eval_metric_ops = dict()
                return tf.estimator.EstimatorSpec(
                    mode, loss=self.reported_loss,
                    eval_metric_ops=self.eval_metric_ops)

    @classmethod
    def construct_expected_params(cls):
        expected_params = super().construct_expected_params()
        expected_params["src_vocab_file"] = cls.make_param("")
        expected_params["tgt_vocab_file"] = cls.make_param("")
        expected_params["src_vocab_size"] = cls.make_param(0)
        expected_params["tgt_vocab_size"] = cls.make_param(0)
        return expected_params

    def extract_args(self, features, mode, params):
        super().extract_args(features, mode, params)
        if (self.hparams.src_vocab_size == 0 and self.hparams.tgt_vocab_size == 0
                and self.hparams.src_vocab_file == ""
                and self.hparams.tgt_vocab_file == ""):
            # No separate source/target vocabularies given; share one.
            self.src_vocab = self.vocab
            self.tgt_vocab = self.vocab
        else:
            if self.hparams.src_vocab_size > 0:
                self.src_vocab = Vocabulary(size=self.hparams.src_vocab_size)
            else:
                self.src_vocab = Vocabulary(fname=self.hparams.src_vocab_file)
            if self.hparams.tgt_vocab_size > 0:
                self.tgt_vocab = Vocabulary(size=self.hparams.tgt_vocab_size)
            else:
                self.tgt_vocab = Vocabulary(fname=self.hparams.tgt_vocab_file)

    def init_inputs(self):
        with tf.name_scope('init_encoder'):
            inputs = tf.cast(self.features["inputs"], tf.int32)
            self.batch_size = tf.shape(inputs)[0]
            inputs_length = tf.cast(
                tf.count_nonzero(inputs - self.vocab.end_token_id, -1),
                tf.int32)
            inputs_max_length = tf.reduce_max(inputs_length)
            # Pad every batch out to max_length with the end token.
            end_token = tf.ones(
                shape=[self.batch_size,
                       self.hparams.max_length - inputs_max_length],
                dtype=tf.int32) * self.vocab.end_token_id
            # [batch_size, max_length]
            self.inputs_sparse = tf.concat([inputs, end_token], axis=1)

    def init_targets(self):
        with tf.name_scope('init_decoder'):
            targets = tf.cast(self.features["targets"], tf.int32)
            targets_length = tf.cast(
                tf.count_nonzero(targets - self.vocab.end_token_id, -1),
                tf.int32)
            targets_max_length = tf.reduce_max(targets_length)
            end_token = tf.ones(
                shape=[self.batch_size,
                       self.hparams.max_length - targets_max_length],
                dtype=tf.int32) * self.vocab.end_token_id
            # [batch_size, max_length]
            self.targets_sparse = tf.concat([targets, end_token], axis=1)
            # +1 so the loss also covers the end-of-sequence token.
            self.targets_length = targets_length + 1
            self.target_mask = tf.sequence_mask(
                lengths=self.targets_length,
                maxlen=self.hparams.max_length, dtype=tf.float32)

    def build_cell(self):
        # The parent's build_cell() only constructs a raw cell and accepts
        # no extra arguments, so the original super() call here (which
        # passed a sequence_length tensor and a vocab size) could not run.
        # build_rnn() builds the deep cell, embeds the inputs, and runs
        # tf.nn.dynamic_rnn, producing the self.outputs that build_obj()
        # consumes below.
        self.build_rnn()

    def build_obj(self):
        output_layer = Dense(self.tgt_vocab.size(), name='output_projection')
        self.logits = output_layer(self.outputs)
    def build_rt_decoder(self):
        with tf.name_scope('predict_decoder'):
            # Greedy decode: per-step argmax, scored by the sum over time
            # of the winning log-probabilities.
            self.hypotheses = tf.argmax(self.logits, -1)
            self.scores = tf.reduce_sum(
                tf.reduce_max(tf.nn.log_softmax(self.logits), -1), -1)
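
# A toy shape check (assuming a TF 1.x runtime) for the greedy decoding in
# build_rt_decoder above: hypotheses are the per-step argmax, and each score
# is the sum over time of the winning log-probabilities. The dimensions here
# are arbitrary demo values.
def _demo_greedy_decode(batch=2, time=5, vocab=7):
    logits = tf.random_normal([batch, time, vocab])
    hypotheses = tf.argmax(logits, -1)
    scores = tf.reduce_sum(tf.reduce_max(tf.nn.log_softmax(logits), -1), -1)
    with tf.Session() as sess:
        hyp, sc = sess.run([hypotheses, scores])
        print(hyp.shape, sc.shape)  # -> (2, 5) (2,)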