def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
           do_topk=False):
    """
    V1 version of: sample outputs from a model, and do it all at once
    :param news_config: Configuration used to construct the model
    :param initial_context: [batch_size, seq_length] that we'll start generating with
    :param eos_token: Stop generating if you see this (tf scalar)
    :param min_len: min length of sample
    :param ignore_ids: NEVER GENERATE THESE [vocab_size]
    :param p_for_topp: top-p threshold (or the k value when do_topk is True)
    :param do_topk: if True, use top-k sampling instead of top-p
    :return: generated tokens and their per-token probabilities
    """
    batch_size, _ = get_shape_list(initial_context, expected_rank=2)

    if ignore_ids is None:
        ignore_ids = tf.constant([x == 0 for x in range(news_config.vocab_size)], dtype=tf.bool)

    with tf.name_scope('sample_sequence'):
        # Initial call to get cache
        context_output = initialize_from_context(initial_context, ignore_ids=ignore_ids, news_config=news_config,
                                                 p_for_topp=p_for_topp, do_topk=do_topk)
        ctx = context_output['tokens']
        cache = context_output['cache']
        probs = context_output['probs']

        def body(ctx, cache, probs):
            """ for whatever reason this didn't work when I ran it on more than one at once... ugh."""
            next_outputs = sample_step(ctx[:, -1][:, None], ignore_ids=ignore_ids, news_config=news_config,
                                       batch_size=batch_size, p_for_topp=p_for_topp, cache=cache,
                                       do_topk=do_topk)

            # Update everything
            new_cache = tf.concat([cache, next_outputs['new_cache']], axis=-2)
            new_ids = tf.concat([ctx, next_outputs['new_tokens'][:, None]], axis=1)
            new_probs = tf.concat([probs, next_outputs['new_probs'][:, None]], axis=1)
            return [new_ids, new_cache, new_probs]

        def cond(ctx, cache, probs):
            # ctx = tf.Print(ctx,[tf.shape(ctx)])
            is_eos = tf.reduce_all(tf.reduce_any(tf.equal(ctx[:, -1:], eos_token), axis=1))
            is_len = tf.greater(get_shape_list(ctx)[1], min_len)
            return tf.logical_not(tf.logical_and(is_eos, is_len))

        tokens, cache, probs = tf.while_loop(
            cond=cond, body=body, maximum_iterations=1025 - get_shape_list(ctx)[1],
            loop_vars=[ctx, cache, probs],
            shape_invariants=[tf.TensorShape([batch_size, None]),
                              tf.TensorShape(
                                  [batch_size, news_config.num_hidden_layers, 2,
                                   news_config.num_attention_heads,
                                   None, news_config.hidden_size // news_config.num_attention_heads]),
                              tf.TensorShape([batch_size, None]),
                              ],
            back_prop=False,
        )
    return tokens, probs
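# The `cond` above keeps the while_loop running until every sequence in the batch has just
# emitted `eos_token` AND the generated length exceeds `min_len`. A minimal NumPy sketch of that
# stopping test (names mirror the loop variables above; this is an illustration, not graph code,
# and the EOS id used in the demo is an assumption):
#
# import numpy as np
#
# def should_continue(ctx, eos_token, min_len):
#     """NumPy mirror of the while_loop condition: keep sampling until every row's
#     most recent token equals eos_token AND the sequence is longer than min_len."""
#     is_eos = np.all(ctx[:, -1] == eos_token)   # all rows just produced EOS
#     is_len = ctx.shape[1] > min_len            # minimum length reached
#     return not (is_eos and is_len)
#
# ctx = np.array([[10, 20, 50256],
#                 [11, 21, 50256]])
# print(should_continue(ctx, eos_token=50256, min_len=2))  # False -> loop stops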
def residual_mlp_layer(x_flat, intermediate_size, initializer_range=0.02, hidden_dropout_prob=0.1):
    """
    :param x_flat: The attention output. It should be [batch_size*seq_length, dim]
    :param intermediate_size: the hidden projection. By default this is the input_dim * 4.

    in the original GPT we would return layer_norm(x_norm + h1) rather than layer_norm(x + h1)
    :return: [batch_size*seq_length, dim] output of the residual MLP block
    """
    batch_size_seq_length, hidden_size = get_shape_list(x_flat, expected_rank=2)
    x_norm = layer_norm(x_flat, name='mlp_ln0')

    intermediate_output = tf.layers.dense(
        x_norm,
        intermediate_size,
        activation=gelu,
        kernel_initializer=create_initializer(initializer_range),
        name='intermediate',
    )

    output_for_residual = tf.layers.dense(
        intermediate_output,
        hidden_size,
        name='output',
        kernel_initializer=create_initializer(initializer_range))
    output_for_residual = dropout(output_for_residual, hidden_dropout_prob)

    layer_output = layer_norm(x_flat + output_for_residual, name='mlp_ln1')
    return layer_output
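# A compact NumPy sketch of the computation above (pre-layer-norm, two dense projections with GELU,
# residual add, post-layer-norm). This is only an illustration of the math: the learned layer-norm
# gain/bias are omitted and the weight shapes are made up, not the checkpoint's.

import numpy as np

def _gelu_np(x):
    # tanh approximation of GELU, as commonly used in transformer code
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def _layer_norm_np(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def _residual_mlp_np(x_flat, w1, b1, w2, b2):
    """x_flat: [batch*seq, dim]; w1: [dim, 4*dim]; w2: [4*dim, dim]."""
    x_norm = _layer_norm_np(x_flat)                   # mlp_ln0
    h = _gelu_np(x_norm @ w1 + b1)                    # 'intermediate'
    out = h @ w2 + b2                                 # 'output'
    return _layer_norm_np(x_flat + out)               # residual + mlp_ln1

_dim = 8
_x = np.random.randn(4, _dim)
_w1, _b1 = np.random.randn(_dim, 4 * _dim) * 0.02, np.zeros(4 * _dim)
_w2, _b2 = np.random.randn(4 * _dim, _dim) * 0.02, np.zeros(_dim)
print(_residual_mlp_np(_x, _w1, _b1, _w2, _b2).shape)  # (4, 8)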
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling. if ignore_ids is on, then we will zero out those logits.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param k: number of highest-probability tokens to keep, either an int or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)

        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')

        # find the top kth index to cut off. careful we don't want to cutoff everything!
        # result will be [batch_size, vocab_perm]
        k_expanded = k if isinstance(k, int) else k[:, None]
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
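# The essential top-k trick above is: sort logits descending, mask out everything past rank k,
# sample in the sorted space, then map the sampled rank back to a vocabulary id. A hedged NumPy
# sketch of that logic (toy shapes, not the TF graph):

import numpy as np

def _top_k_sample_np(logits, k, rng=np.random.default_rng(0)):
    """logits: [batch, vocab]. Keep only the k largest logits per row, then sample one id per row."""
    batch, vocab = logits.shape
    indices = np.argsort(-logits, axis=-1)                  # descending sort, like tf.argsort
    sorted_logits = np.take_along_axis(logits, indices, axis=-1)
    sorted_logits[:, k:] = -1e10                            # exclude everything past rank k
    probs = np.exp(sorted_logits - sorted_logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    sample_perm = np.array([rng.choice(vocab, p=p) for p in probs])    # sample in sorted space
    return np.take_along_axis(indices, sample_perm[:, None], axis=-1)  # unsort back to vocab ids

_logits = np.random.randn(2, 12)
print(_top_k_sample_np(_logits, k=3))   # shape (2, 1), ids drawn from each row's top 3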
def initialize_from_context(initial_context, ignore_ids, news_config, p_for_topp=0.95, do_topk=False):
    """ same signature as sample_step"""
    batch_size, _ = get_shape_list(initial_context, expected_rank=2)

    context_output = sample_step(tokens=initial_context, ignore_ids=ignore_ids, news_config=news_config,
                                 batch_size=batch_size, p_for_topp=p_for_topp, cache=None, do_topk=do_topk)
    return {
        'tokens': tf.concat([initial_context, context_output['new_tokens'][:, None]], 1),
        'cache': context_output['new_cache'],
        'probs': context_output['new_probs'][:, None]
    }
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p sampling. if ignore_ids is on, then we will zero out those logits.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param p: topp threshold to use, either a float or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(
                    logits=logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                    num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)

        # find the top pth index to cut off. careful we don't want to cutoff everything!
        # result will be [batch_size, vocab_perm]
        p_expanded = p if isinstance(p, float) else p[:, None]
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

        # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
        # unperm_indices = tf.argsort(indices, direction='ASCENDING')
        # include_mask_unperm = tf.batch_gather(include_mask, unperm_indices)
        # logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10
        # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)

    return {
        'probs': probs,
        'sample': sample,
    }
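# For comparison with the top-k variant, the nucleus cutoff above keeps the smallest prefix of the
# descending-sorted distribution whose (inclusive) cumulative probability is still below p, and it
# always keeps at least the single most likely token. A NumPy sketch of just the mask construction
# (an illustration with toy numbers, not the graph code):

import numpy as np

def _top_p_exclude_mask_np(probs_sorted, p):
    """probs_sorted: [batch, vocab] probabilities already sorted descending.
    Returns True where a token should be excluded, mirroring exclude_mask above."""
    cumulative = np.cumsum(probs_sorted, axis=-1)
    keep = (cumulative < p) | (np.arange(probs_sorted.shape[-1])[None] < 1)  # always keep rank 0
    return ~keep

_probs_sorted = np.array([[0.4, 0.3, 0.2, 0.06, 0.04]])
print(_top_p_exclude_mask_np(_probs_sorted, p=0.8))
# [[False False  True  True  True]] -> only the two most likely tokens stay in the nucleus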
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p or top-k threshold
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :return: new_tokens, size [batch_size]
             new_probs, also size [batch_size]
             new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
                   news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )

    # Extract the FINAL SEQ LENGTH
    batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    if do_topk:
        sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    new_tokens = tf.squeeze(sample_info['sample'], 1)
    new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
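# The one shape manipulation worth calling out here is recovering the last position's logits from
# the flat [batch_size*seq_length, vocab_size] output. A NumPy illustration with toy dimensions:

import numpy as np

_batch_size, _seq_length, _vocab_size = 2, 5, 7
_logits_flat = np.random.randn(_batch_size * _seq_length, _vocab_size)  # like model.logits_flat

# Reshape back to [batch, seq, vocab] and keep only the final time step,
# mirroring tf.reshape(...)[:, -1] above.
_next_logits = _logits_flat.reshape(_batch_size, -1, _vocab_size)[:, -1]
print(_next_logits.shape)  # (2, 7): one next-token distribution per sequence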
def _attention_projection_and_transpose(x_flat, batch_size, seq_length, num_attention_heads, size_per_head,
                                        name, initializer_range=0.02):
    """
    :param x_flat: [batch_size*seq_length, width]
    :return: A fixed up tensor of size [batch_size, num_attention_heads, seq_length, size_per_head]
    """
    batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2)

    if dim != size_per_head * num_attention_heads:
        raise ValueError("passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}".format(
            (batch_size_seq_length, dim), size_per_head, num_attention_heads
        ))

    projected = tf.layers.dense(
        x_flat,
        num_attention_heads * size_per_head,
        name=name,
        kernel_initializer=create_initializer(initializer_range))

    projected = tf.reshape(
        projected, [batch_size, seq_length, num_attention_heads, size_per_head])
    output_tensor = tf.transpose(projected, [0, 2, 1, 3])
    return output_tensor
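# The reshape-then-transpose at the end is the standard multi-head layout change from
# [batch*seq, heads*head_dim] to [batch, heads, seq, head_dim]. A quick NumPy check of the shapes
# (illustrative only):

import numpy as np

_batch, _seq, _heads, _head_dim = 2, 4, 3, 5
_x_flat = np.random.randn(_batch * _seq, _heads * _head_dim)   # projected activations

# [batch*seq, heads*head_dim] -> [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
_split = _x_flat.reshape(_batch, _seq, _heads, _head_dim)
_per_head = _split.transpose(0, 2, 1, 3)
print(_per_head.shape)  # (2, 3, 4, 5)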
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = GroverModel(
        config=config,
        is_training=is_training,
        input_ids=input_ids,
        pad_token_id=config.pad_token_id,
        chop_off_last_token=True,
    )

    total_loss = model.lm_loss()

    if is_training:
        train_op, train_metrics = optimization_adafactor.create_optimizer(
            total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
        tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    else:
        train_op = None
        train_metrics = {}
        tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
        (assignment_map, initialized_variable_names
         ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        if use_tpu:
            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape, init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if use_tpu:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                host_call=construct_scalar_host_call(metric_dict=train_metrics, model_dir=params['model_dir'],
                                                     prefix='training/'),
                scaffold_fn=scaffold_fn)
        else:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[
                    tf.train.LoggingTensorHook({'loss': tf.metrics.mean(total_loss)[1]}, every_n_iter=100)],
                scaffold_fn=scaffold_fn)

    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(loss):
            """Evaluation metric Fn which runs on CPU."""
            perplexity = tf.exp(tf.reduce_mean(loss))
            bpc = tf.reduce_mean(loss) / tf.constant(math.log(2))
            return {
                "perplexity": tf.metrics.mean(perplexity),
                "bpc": tf.metrics.mean(bpc),
            }

        if use_tpu:
            with tf.colocate_with(total_loss):
                total_loss = tf.contrib.tpu.cross_replica_sum(total_loss) / num_tpu_cores

        metric_loss = tf.tile(tf.reshape(total_loss, [1, 1]), [eval_batch_size, 1])
        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=(metric_fn, [metric_loss]),
            scaffold_fn=scaffold_fn)

    else:
        gt_logprobs = tf.squeeze(tf.batch_gather(model.log_probs, model.target_ids[:, :, None]), axis=2)

        # Need top-p required under topp sampling!
        better_than_gt = model.log_probs > gt_logprobs[:, :, None]
        top_p_required = tf.reduce_sum(tf.cast(better_than_gt, tf.float32) * tf.exp(model.log_probs), axis=2)

        # No top-p sampling for now, since this seems to be too slow on TPUs
        if use_tpu:
            predictions = tf.reshape(
                tf.random.categorical(logits=model.logits_flat, num_samples=1),
                get_shape_list(model.target_ids),
            )
        else:
            # Argmax
            # predictions = tf.math.argmax(model.log_probs, axis=-1, output_type=tf.int32)
            predictions = tf.reshape(
                _top_p_sample(model.logits_flat, num_samples=1, p=0.99)['sample'],
                get_shape_list(model.target_ids),
            )
        pred_logprobs = tf.squeeze(tf.batch_gather(model.log_probs, predictions[:, :, None]), axis=2)

        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions={'gt_logprobs': gt_logprobs,
                         'top_p_required': top_p_required,
                         'predictions': predictions,
                         'pred_logprobs': pred_logprobs,
                         'labels': input_ids},
            scaffold_fn=scaffold_fn)
    return output_spec
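# The `top_p_required` quantity above is the probability mass the model assigns to tokens it scores
# higher than the ground-truth token, i.e. the smallest top-p threshold at which the true token
# would still fall inside the nucleus. A NumPy sketch of the same computation for a single position
# (toy numbers; the real tensors are [batch, seq, vocab]):

import numpy as np

_log_probs = np.log(np.array([[0.5, 0.3, 0.15, 0.05]]))  # [positions=1, vocab=4]
_target_ids = np.array([2])                               # ground truth is the 0.15 token

_gt_logprobs = np.take_along_axis(_log_probs, _target_ids[:, None], axis=-1)[:, 0]
_better_than_gt = _log_probs > _gt_logprobs[:, None]
_top_p_required = np.sum(_better_than_gt * np.exp(_log_probs), axis=-1)
print(_top_p_required)  # ~[0.8]: the 0.5 and 0.3 tokens outrank the true token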
def __init__(self, config: GroverConfig, is_training, input_ids, cache=None, do_cache=False, pad_token_id=0,
             chop_off_last_token=True, scope=None, reuse=False):
    """
    :param config:
    :param is_training:
    :param input_ids: Tensor that's of size [batch_size, seq_length]
    :param cache: Optionally, a tensor to use that will contain cached information of the size
        [batch_size, num_layers, 2, num_heads, cache_length, features]
    :param do_cache: Whether to cache again.
    :param pad_token_id: Which token will be used for padding (probably 0.)
    :param chop_off_last_token: True if we will end up using this for TRAINING only. False if we want to generate.
        it means the last token in input_ids will not be processed by the model as input
    :param scope: scope to run this on
    :param reuse: whether to reuse variables in that scope
    """
    self.config = copy.deepcopy(config)
    self.is_training = is_training
    self.pad_token_id = pad_token_id

    if not is_training:
        self.config.hidden_dropout_prob = 0.0
        self.config.attention_probs_dropout_prob = 0.0

    if chop_off_last_token:
        self.target_ids = input_ids[:, 1:]
        self.input_ids = input_ids[:, :-1]
    else:
        self.input_ids = input_ids
        self.target_ids = tf.concat((input_ids[:, 1:],
                                     tf.constant(self.pad_token_id, dtype=self.input_ids.dtype,
                                                 shape=[get_shape_list(self.input_ids, 2)[0], 1])), 1)

    self.batch_size, self.seq_length = get_shape_list(self.input_ids, 2)

    if cache is None:
        caches = [None] * config.num_hidden_layers
        self.cache_length = 0
    else:
        batch_size_, num_layers_, two_, num_heads_, self.cache_length, features_ = get_shape_list(
            cache, expected_rank=6)
        assert batch_size_ == self.batch_size
        assert num_layers_ == config.num_hidden_layers
        assert two_ == 2
        assert num_heads_ == config.num_attention_heads
        assert features_ == (config.hidden_size // config.num_attention_heads)
        caches = tf.unstack(cache, axis=1)

    with tf.variable_scope(scope, default_name='newslm', reuse=reuse):
        with tf.variable_scope("embeddings"):
            embeddings, self.embedding_table = embed(self.input_ids, config.vocab_size,
                                                     config.hidden_size,
                                                     position_offset=self.cache_length,
                                                     initializer_range=config.initializer_range,
                                                     max_position_embeddings=config.max_position_embeddings,
                                                     use_one_hot_embeddings=True)

        mask = get_attention_mask(self.seq_length, self.seq_length + self.cache_length, dtype=embeddings.dtype)

        # We keep the representation as a 2D tensor to avoid re-shaping it back and
        # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
        # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
        # help the optimizer.
        hidden_state = tf.reshape(embeddings, [self.batch_size * self.seq_length, self.config.hidden_size])

        new_kvs = []
        for layer_idx, layer_cache in enumerate(caches):
            with tf.variable_scope('layer{:02d}'.format(layer_idx)):
                # [batch_size * seq_length, hidden_size]
                attention_output, new_kv = attention_layer(
                    hidden_state,
                    mask,
                    batch_size=self.batch_size,
                    seq_length=self.seq_length,
                    size_per_head=config.hidden_size // config.num_attention_heads,
                    num_attention_heads=config.num_attention_heads,
                    initializer_range=config.initializer_range,
                    hidden_dropout_prob=self.config.hidden_dropout_prob,
                    attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
                    do_cache=do_cache,
                    cache=layer_cache,
                )
                new_kvs.append(new_kv)

                # [batch_size * seq_length, hidden_size]
                hidden_state = residual_mlp_layer(hidden_state + attention_output,
                                                  intermediate_size=config.intermediate_size,
                                                  hidden_dropout_prob=self.config.hidden_dropout_prob)

        self.hidden_state = hidden_state
        self.new_kvs = tf.stack(new_kvs, axis=1) if do_cache else None

        # Note that the hidden state is still flat (batch_size*hidden_size)
        self.logits_flat = tf.matmul(self.hidden_state, self.embedding_table, transpose_b=True)
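# The final `logits_flat` line is weight tying: output logits are computed against the input
# embedding table rather than a separate softmax matrix. A NumPy sketch of that projection
# (shapes here are illustrative, not the model's):

import numpy as np

_batch_times_seq, _hidden, _vocab = 6, 8, 20
_hidden_state = np.random.randn(_batch_times_seq, _hidden)        # flat transformer output
_embedding_table = np.random.randn(_vocab, _hidden) * 0.02        # same table used to embed inputs

# tf.matmul(hidden_state, embedding_table, transpose_b=True)
_logits_flat = _hidden_state @ _embedding_table.T
print(_logits_flat.shape)  # (6, 20): one score per vocabulary entry, per position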
def embed(input_ids,
          vocab_size,
          embedding_size,
          position_offset=0,
          initializer_range=0.02,
          max_position_embeddings=512,
          use_one_hot_embeddings=True):
    """Word and position embeddings

    :param input_ids: int Tensor of shape [batch_size, seq_length].
    :param vocab_size: number of words in vocab
    :param embedding_size: dimensionality of the embedding
    :param position_offset: aka number of cached tokens.
    :param initializer_range: float. Range of the weight initialization.
    :param max_position_embeddings: int. Maximum sequence length.
    :param use_one_hot_embeddings: probably want this to be true
    :return: [batch_size, seq_length, embedding_size] embedded tensor
    """
    (batch_size, seq_length) = get_shape_list(input_ids, expected_rank=2)

    embedding_table = tf.get_variable(
        name='word_embed',
        shape=[vocab_size, embedding_size],
        initializer=create_initializer(initializer_range),
    )

    assert_op = tf.assert_less_equal(tf.reduce_max(input_ids), vocab_size - 1)
    with tf.control_dependencies([assert_op]):
        if use_one_hot_embeddings:
            flat_input_ids = tf.reshape(input_ids, [-1])
            one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
            output_flat = tf.matmul(one_hot_input_ids, embedding_table)
        else:
            output_flat = tf.nn.embedding_lookup(embedding_table, input_ids)

        embedded_input = tf.reshape(output_flat, [batch_size, seq_length, embedding_size])

    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)

    with tf.control_dependencies([assert_op]):
        full_position_embeddings = tf.get_variable(
            name='pos_embed',
            shape=[max_position_embeddings, embedding_size],
            initializer=create_initializer(initializer_range),
        )
        # Since the position embedding table is a learned variable, we create it
        # using a (long) sequence length `max_position_embeddings`. The actual
        # sequence length might be shorter than this, for faster training of
        # tasks that do not have long sequences.
        #
        # So `full_position_embeddings` is effectively an embedding table
        # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
        # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
        # perform a slice.
        if position_offset == 0:
            embedded_input += tf.slice(full_position_embeddings, [0, 0], [seq_length, -1])[None]
        else:
            # Tensorflow is too stupid to allow slicing
            flat_pos_ids = (tf.range(seq_length, dtype=tf.int32) + position_offset)
            one_hot_pos_ids = tf.one_hot(flat_pos_ids, depth=max_position_embeddings)

            # [seq_length, full_position_embeddings], [full_position_embeddings, dim]
            seq_embeds = tf.matmul(one_hot_pos_ids, full_position_embeddings)

            embedded_input += seq_embeds[None]
            # embedded_input += tf.slice(full_position_embeddings[position_offset:], [0, 0], [seq_length, -1])[None]

    return layer_norm(embedded_input, name='embed_norm'), embedding_table
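# When decoding with a cache, the positions of the new tokens start at `position_offset` rather
# than 0; the one-hot matmul in the `else` branch above is just a TPU-friendly way of gathering
# those rows. A NumPy sketch showing the gather it is equivalent to (toy shapes):

import numpy as np

_max_positions, _dim = 10, 4
_full_position_embeddings = np.random.randn(_max_positions, _dim)

_seq_length, _position_offset = 3, 5                         # 3 new tokens, 5 already cached
_pos_ids = np.arange(_seq_length) + _position_offset         # [5, 6, 7]

# One-hot matmul (mirrors the TPU-friendly branch) ...
_one_hot = np.eye(_max_positions)[_pos_ids]
_via_matmul = _one_hot @ _full_position_embeddings
# ... is the same as a plain row gather / slice.
_via_gather = _full_position_embeddings[_pos_ids]
print(np.allclose(_via_matmul, _via_gather))  # True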
def attention_layer(x_flat, attention_mask, batch_size, seq_length, size_per_head=512, num_attention_heads=1, *,
                    cache=None,
                    initializer_range=0.02, hidden_dropout_prob=0.1,
                    attention_probs_dropout_prob=0.1, do_cache=False):
    """
    :param x_flat: Tensor input, should be [batch_size*seq_length, dim]
    :param attention_mask: Attention mask to use of size [seq_length, seq_length+cached_length]
    :param size_per_head: dim = size_per_head * num_attention_heads
    :param num_attention_heads: dim = size_per_head * num_attention_heads
    :param cache: Optionally some past (cached) things of size
                [batch, 2, heads, sequence, features], where 2 is [k, v]
    :param do_cache: True if we should return cache
    :return: A new tensor of shape [batch_size, seq_length, dim]
    as well as a new cache "cached_keys_and_values" that will be of size
                                   [batch_size, 2, num_attention_heads, seq_length, dim]
    """
    batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2)

    if dim != size_per_head * num_attention_heads:
        raise ValueError("passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}".format(
            (batch_size_seq_length, dim), size_per_head, num_attention_heads
        ))

    query = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length,
                                                num_attention_heads=num_attention_heads, size_per_head=size_per_head,
                                                name='query_layer',
                                                initializer_range=initializer_range)
    key = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length,
                                              num_attention_heads=num_attention_heads, size_per_head=size_per_head,
                                              name='key_layer',
                                              initializer_range=initializer_range)
    value = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length,
                                                num_attention_heads=num_attention_heads, size_per_head=size_per_head,
                                                name='value_layer',
                                                initializer_range=initializer_range)

    # Add to cache
    cached_keys_and_values = tf.stack([key, value], axis=1) if do_cache else None

    # Things that were relevant from the cache
    if cache is not None:
        pk, pv = tf.unstack(cache, axis=1)
        key = tf.concat([pk, key], axis=-2)
        value = tf.concat([pv, value], axis=-2)

    # Multiply [batch_size, num_attention_heads, seq_length, size_per_head] with
    # [batch_size, num_attention_heads, size_per_head, seq_length+cached_length] ->
    # [batch_size, num_attention_heads, seq_length, seq_length+cached_length]
    attention_scores = tf.matmul(query, key, transpose_b=True)
    attention_scores = tf.multiply(attention_scores, 1.0 / math.sqrt(float(size_per_head)))
    attention_scores = mask_attention_for_ltr(attention_scores, attention_mask)
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    # NOPENOPENOPENOPE
    # attention_probs = factoreddropout(attention_probs, attention_probs_dropout_prob)

    # Multiply [batch_size, num_attention_heads, seq_length, seq_length+cached_length] with
    # [batch_size, num_attention_heads, seq_length+cached_length, size_per_head] ->
    # [batch_size, num_attention_heads, seq_length, size_per_head] ->
    context_layer = tf.matmul(attention_probs, value)

    # `context_layer` = [batch_size, seq_length, num_attention_heads, size_per_head]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
    context_layer = tf.reshape(context_layer, [batch_size * seq_length, num_attention_heads * size_per_head])

    context_layer_projected = tf.layers.dense(
        context_layer,
        num_attention_heads * size_per_head,
        kernel_initializer=create_initializer(initializer_range),
        name='context_projection_layer'
    )
    context_layer_projected = dropout(context_layer_projected, hidden_dropout_prob)

    return context_layer_projected, cached_keys_and_values
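# To make the cache handling concrete: keys/values from previous steps are concatenated in front
# of the current step's keys/values along the sequence axis, and the causal mask lets each new
# query attend to everything cached plus itself. A self-contained NumPy sketch of that core
# (single head, no projections or dropout; an illustration, not the graph code):

import numpy as np

def _cached_causal_attention(q, k_new, v_new, k_cache=None, v_cache=None):
    """q, k_new, v_new: [seq_new, d]. k_cache/v_cache: [seq_cached, d] or None."""
    k = k_new if k_cache is None else np.concatenate([k_cache, k_new], axis=0)
    v = v_new if v_cache is None else np.concatenate([v_cache, v_new], axis=0)
    n_cached = k.shape[0] - k_new.shape[0]

    scores = q @ k.T / np.sqrt(q.shape[-1])                       # [seq_new, seq_cached+seq_new]
    # causal mask: new position i may attend to all cached slots and to new slots <= i
    j = np.arange(k.shape[0])[None, :]
    i = np.arange(q.shape[0])[:, None] + n_cached
    scores = np.where(j <= i, scores, -1e10)

    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    return probs @ v                                              # [seq_new, d]

_d = 4
_cache_k, _cache_v = np.random.randn(3, _d), np.random.randn(3, _d)   # 3 cached tokens
_q = _k = _v = np.random.randn(2, _d)                                 # 2 new tokens
print(_cached_causal_attention(_q, _k, _v, _cache_k, _cache_v).shape)  # (2, 4)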
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    for (grad, param) in grads_and_vars:
        if grad is None or param is None:
            continue

        param_name = self._get_variable_name(param.name)
        shape_list = get_shape_list(param, expected_rank=[1, 2])

        # decay_rate = 1 - tf.pow(tf.cast(tf.train.get_or_create_global_step(), tf.float32) + 1.0, -0.8)
        decay_rate = self.beta_2
        grad_squared = tf.square(grad) + self.epsilon1

        update_scale = self.learning_rate
        # update_scale = self.learning_rate * tf.cast(self._parameter_scale(param), dtype=tf.float32)

        # HACK: Make things dependent on grad.
        # This confounds the XLA rewriter and keeps it from fusing computations
        # across different variables. This fusion is bad for HBM usage, since
        # it causes the gradients to persist in memory.
        grad_squared_mean = tf.reduce_mean(grad_squared)
        decay_rate += grad_squared_mean * 1e-30
        update_scale += grad_squared_mean * 1e-30
        # END HACK

        if self._use_factored(shape_list):
            num_rows, num_columns = shape_list

            vr = tf.get_variable(
                name=param_name + "/adafactor_vr",
                shape=[num_rows],
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
            vc = tf.get_variable(
                name=param_name + "/adafactor_vc",
                shape=[num_columns],
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())

            next_vr = decay_rate * vr + (1 - decay_rate) * tf.reduce_mean(grad_squared, 1)
            next_vc = decay_rate * vc + (1 - decay_rate) * tf.reduce_mean(grad_squared, 0)

            long_term_mean = tf.reduce_mean(next_vr, -1, keepdims=True)
            r_factor = tf.rsqrt(next_vr / long_term_mean + self.epsilon1)
            c_factor = tf.rsqrt(next_vc + self.epsilon1)
            update = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2)

            assignments.append(vr.assign(next_vr, use_locking=self.use_locking))
            assignments.append(vc.assign(next_vc, use_locking=self.use_locking))
        else:
            v = tf.get_variable(
                name=param_name + "/adafactor_v",
                shape=shape_list,
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
            next_v = decay_rate * v + (1 - decay_rate) * grad_squared

            assignments.append(v.assign(next_v, use_locking=self.use_locking))
            update = grad * tf.rsqrt(next_v + self.epsilon1)

        clipping_denom = tf.maximum(1.0, reduce_rms(update) / self.clipping_rate)
        update /= clipping_denom

        # Do weight decay
        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters. This is equivalent to adding the square
        # of the weights to the loss with plain (non-momentum) SGD.
        if self._do_use_weight_decay(param_name):
            update += self.weight_decay_rate * param

        update_with_lr = update_scale * update
        next_param = param - update_with_lr

        assignments.append(param.assign(next_param, use_locking=self.use_locking))

    return tf.group(*assignments, name=name)
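# The factored branch above stores only a per-row and a per-column running average of the squared
# gradients instead of a full second-moment matrix (O(rows + cols) memory instead of O(rows * cols)).
# A NumPy sketch of one factored update step, with the weight decay, RMS clipping, and the XLA hack
# omitted; the hyperparameter values are illustrative defaults, not the optimizer's configuration.

import numpy as np

def _adafactor_factored_step(param, grad, vr, vc, lr=0.01, beta2=0.999, eps1=1e-30):
    """param, grad: [rows, cols]; vr: [rows] and vc: [cols] running second-moment factors."""
    grad_squared = grad ** 2 + eps1
    vr = beta2 * vr + (1 - beta2) * grad_squared.mean(axis=1)   # row statistics
    vc = beta2 * vc + (1 - beta2) * grad_squared.mean(axis=0)   # column statistics

    long_term_mean = vr.mean(keepdims=True)
    r_factor = 1.0 / np.sqrt(vr / long_term_mean + eps1)
    c_factor = 1.0 / np.sqrt(vc + eps1)
    update = grad * r_factor[:, None] * c_factor[None, :]       # rank-1 rescaling of the gradient
    return param - lr * update, vr, vc

_rows, _cols = 4, 6
_param, _grad = np.random.randn(_rows, _cols), np.random.randn(_rows, _cols)
_vr, _vc = np.zeros(_rows), np.zeros(_cols)
_param, _vr, _vc = _adafactor_factored_step(_param, _grad, _vr, _vc)
print(_param.shape, _vr.shape, _vc.shape)  # (4, 6) (4,) (6,)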