def add_timing_signal_1d_given_position(x,
                                        position,
                                        min_timescale=1.0,
                                        max_timescale=1.0e4):
  """Adds sinusoids of different frequencies to a Tensor, with positions given.

  Args:
    x: a Tensor with shape [batch, length, channels]
    position: a Tensor with shape [batch, length]
    min_timescale: a float
    max_timescale: a float

  Returns:
    a Tensor the same shape as x: x with the timing signal added.
  """
  channels = common_layers.shape_list(x)[2]
  num_timescales = channels // 2
  # With fewer than 4 channels there is a single timescale and
  # (num_timescales - 1) would be 0, producing inf/NaN inverse timescales.
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      tf.maximum(tf.to_float(num_timescales) - 1, 1))
  inv_timescales = min_timescale * tf.exp(
      tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  # [batch, length, 1] * [1, 1, num_timescales]
  # -> [batch, length, num_timescales]
  scaled_time = (
      tf.expand_dims(tf.to_float(position), 2) *
      tf.expand_dims(tf.expand_dims(inv_timescales, 0), 0))
  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
  # For an odd channel count, pad one zero channel so the signal matches x.
  signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]])
  signal = common_layers.cast_like(signal, x)
  # The function *adds* the signal; returning it alone would discard x.
  return x + signal
def prepare_decoder(targets, hparams):
  """Prepare decoder for images.

  Args:
    targets: the target image Tensor. During training it is
      [batch, IMG_LEN, IMG_LEN, 3]; at inference it is
      [batch, curr_infer_length, 1, 1].
    hparams: model hyperparameters. Reads num_channels, mode, img_len,
      block_raster_scan, query_shape, dec_attention_type and hidden_size.

  Returns:
    A tuple (x, x_shape[1], x_shape[2]):
      x: the right-shifted decoder input with position signals added, cast
        like the (possibly padded/reshaped) targets.
      x_shape[1], x_shape[2]: its 2nd and 3rd dimensions.

  Raises:
    ValueError: at inference with block_raster_scan enabled, if
      img_len * num_channels is not a multiple of query_shape[1], or img_len
      is not a multiple of query_shape[0].
  """
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  # tf.estimator.ModeKeys.PREDICT has the same value ("infer") as the
  # deprecated tf.contrib.learn.ModeKeys.INFER.
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      total_block_width = hparams.img_len * channels
      # Explicit checks rather than `assert`, which is stripped under -O.
      if total_block_width % hparams.query_shape[1] != 0:
        raise ValueError(
            "img_len * num_channels (%s) must be a multiple of "
            "query_shape[1] (%s)." %
            (total_block_width, hparams.query_shape[1]))
      if hparams.img_len % hparams.query_shape[0] != 0:
        raise ValueError(
            "img_len (%s) must be a multiple of query_shape[0] (%s)." %
            (hparams.img_len, hparams.query_shape[0]))
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the target so
      # that rows is a multiple of query_shape and columns is a multiple of
      # hparams.img_len*channels.
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])
      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks.
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding so the length of targets is a multiple of img_len times
      # the number of channels. This is needed for positional encodings and
      # for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])

  # Preprocess image.
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Flatten to one sequence, shift right by one position, restore 2D layout,
    # then add position signals.
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1] * x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, x_shape[1], x_shape[2]
def prepare_decoder(targets, hparams):
  """Build the right-shifted, position-encoded decoder input for images.

  At inference time the (partial) target sequence is first padded and
  reshaped into image layout, optionally in block raster scan order.

  Args:
    targets: target Tensor; [batch, IMG_LEN, IMG_LEN, 3] while training and
      [batch, curr_infer_length, 1, 1] at inference.
    hparams: model hyperparameters.

  Returns:
    (x, rows, cols): the decoder input cast like the reshaped targets, plus
    the 2nd and 3rd dimensions of x.
  """
  targets_shape = common_layers.shape_list(targets)
  batch = targets_shape[0]
  channels = hparams.num_channels

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      row_width = hparams.img_len * channels
      assert row_width % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      block_rows = hparams.query_shape[0]
      block_cols = hparams.query_shape[1]
      # Pad so the sequence covers a whole number of block rows, then view
      # the image as blocks and transpose so it reads in 2D order.
      rows_of_blocks_size = row_width * block_rows
      targets = tf.pad(targets, [
          [0, 0], [0, -infer_length % rows_of_blocks_size], [0, 0], [0, 0]])
      blocks_per_row = row_width // block_cols
      blocked = tf.reshape(
          targets, [batch, -1, blocks_per_row, block_rows, block_cols])
      targets = tf.transpose(blocked, [0, 1, 3, 2, 4])
    else:
      # Pad to a multiple of one full image row (img_len pixels x channels),
      # as required by the positional encodings and the RGB lookup.
      row_size = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -infer_length % row_size], [0, 0], [0, 0]])
    targets = tf.reshape(targets, [batch, -1, hparams.img_len, channels])

  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  height, width = x_shape[1], x_shape[2]
  if hparams.dec_attention_type in (AttentionType.LOCAL_2D,
                                    AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
  else:
    # Shift the whole image right by one position as a flat sequence.
    flat = tf.reshape(x, [batch, height * width, hparams.hidden_size])
    flat = common_layers.shift_right_3d(flat)
    x = tf.reshape(flat, [batch, height, width, hparams.hidden_size])
  x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, height, width
def cast_grad_tpu(g, v):
  """Cast gradient `g` like variable `v`; optionally zero-fill a None gradient.

  Should match upstream t2t
  https://github.com/tensorflow/tensor2tensor/blob/1547c25571633f828ddd74accba76d07d8d043af/tensor2tensor/utils/optimize.py#L232

  NOTE(review): `self` is not a parameter here. It is presumably captured
  from an enclosing method's scope; confirm against the caller.

  Returns:
    A (gradient, variable) tuple.
  """
  have_both = v is not None and g is not None
  if have_both:
    g = common_layers.cast_like(g, v)
  # When enabled, a missing gradient is replaced by zeros shaped like `v`.
  fill_zeros = self._zero_grads and g is None
  return (tf.zeros_like(v) if fill_zeros else g, v)
def bottom_simple(self, x, name, reuse):
  """Embed integer ids `x` with this modality's weights.

  Args:
    x: id Tensor. A rank-4 input has its axis 3 squeezed, and lower ranks
      are expanded up to rank 3.
    name: variable scope name.
    reuse: passed to tf.variable_scope.

  Returns:
    The embedded Tensor. Positions whose (post-dropout) id is 0 are zeroed.
  """
  with tf.variable_scope(name, reuse=reuse):
    # Normalize the input to rank 3.
    if len(x.get_shape()) == 4:
      x = tf.squeeze(x, axis=3)
    while len(x.get_shape()) < 3:
      x = tf.expand_dims(x, axis=-1)

    embedding_weights = self._get_weights()
    keep_prob = 1.0 - self._model_hparams.symbol_dropout
    ids = common_layers.dropout_no_scaling(x, keep_prob)
    embedded = common_layers.gather(embedding_weights, ids)
    if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
      embedded *= self._body_input_depth**0.5
    # Mask out embeddings at id 0.
    nonzero = common_layers.cast_like(tf.not_equal(ids, 0), embedded)
    embedded *= tf.expand_dims(nonzero, -1)
    return embedded
def _resource_apply_dense(self, grad, handle):
  """Build the ops that apply one dense optimizer update to `handle`.

  The second-moment estimate is either factored into row/column running
  means ("vr"/"vc" slots) or kept in full (slot "v"). That choice is made by
  self._should_use_factored_second_moment_estimate. The update is optionally
  RMS-clipped, scaled, smoothed with momentum (slot "m") when self._beta1
  is set, and then subtracted from the variable.

  Args:
    grad: gradient Tensor for the variable (converted to float32).
    handle: the variable being updated.

  Returns:
    A tf.group of the variable assignment and all slot assignments.
  """
  var = handle
  grad = tf.to_float(grad)
  grad_squared = tf.square(grad) + self._epsilon1
  grad_squared_mean = tf.reduce_mean(grad_squared)
  # Decay rate and learning rate may be callables; resolve them here.
  decay_rate = self._call_if_callable(self._decay_rate)
  update_scale = self._call_if_callable(self._learning_rate)
  update_scale = tf.convert_to_tensor(update_scale, name="update_scale")
  update_scale = tf.cast(update_scale, grad_squared_mean.dtype.base_dtype)
  old_val = var
  # bfloat16 variables are stored encoded; decode to float32 for the math.
  if var.dtype.base_dtype == tf.bfloat16:
    old_val = tf.to_float(self._parameter_encoding.decode(old_val))
  if self._multiply_by_parameter_scale:
    update_scale *= tf.to_float(self._parameter_scale(old_val))
  # HACK: Make things dependent on grad.
  # This confounds the XLA rewriter and keeps it from fusing computations
  # across different variables.  This fusion is bad for HBM usage, since
  # it causes the gradients to persist in memory.
  decay_rate += grad_squared_mean * 1e-30
  update_scale += grad_squared_mean * 1e-30
  # END HACK
  mixing_rate = 1.0 - decay_rate
  shape = var.get_shape().as_list()
  updates = []
  if self._should_use_factored_second_moment_estimate(shape):
    # Factored estimate: keep running means over the last axis (rows) and
    # the second-to-last axis (columns) instead of the full matrix.
    grad_squared_row_mean = tf.reduce_mean(grad_squared, -1)
    grad_squared_col_mean = tf.reduce_mean(grad_squared, -2)
    vr = self.get_slot(var, "vr")
    new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
    vc = self.get_slot(var, "vc")
    new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
    vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
    vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
    updates = [vr_update, vc_update]
    # Normalize rows by their mean so that r_factor * c_factor approximates
    # the inverse square root of the full second-moment estimate.
    long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
    r_factor = tf.rsqrt(new_vr / long_term_mean)
    c_factor = tf.rsqrt(new_vc)
    x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(
        c_factor, -2)
  else:
    # Unfactored estimate: a full running mean of squared gradients.
    v = self.get_slot(var, "v")
    new_v = decay_rate * v + mixing_rate * grad_squared
    v_update = tf.assign(v, new_v, use_locking=self._use_locking)
    updates = [v_update]
    x = grad * tf.rsqrt(new_v)
  if self._clipping_threshold is not None:
    # Scale the update down so its RMS does not exceed the threshold.
    clipping_denom = tf.maximum(
        1.0, reduce_rms(x) / self._clipping_threshold)
    x /= clipping_denom
  subtrahend = update_scale * x
  if self._beta1:
    # Momentum: the stored first moment replaces the raw update.
    m = self.get_slot(var, "m")
    new_m = self._beta1 * tf.to_float(m) + (1.0 - self._beta1) * subtrahend
    subtrahend = new_m
    new_m = common_layers.cast_like(new_m, var)
    updates.append(tf.assign(m, new_m, use_locking=self._use_locking))
  new_val = tf.to_float(old_val) - subtrahend
  if var.dtype.base_dtype == tf.bfloat16:
    new_val = self._parameter_encoding.encode(new_val, self._quantization_noise)
  if self._simulated_quantize_bits:
    # NOTE(review): this quantizes `var - subtrahend` rather than the decoded
    # `old_val` path above, so it overrides any bfloat16 encoding. Confirm
    # this is intended.
    new_val = quantization.simulated_quantize(
        var - subtrahend, self._simulated_quantize_bits,
        self._quantization_noise)
  # Assign in the variable's own dtype.
  new_val = tf.cast(new_val, var.dtype)
  var_update = tf.assign(var, new_val, use_locking=self._use_locking)
  updates = [var_update] + updates
  return tf.group(*updates)
def cast_grad(g, v):
  """Return (grad, v) with grad cast like v; zero-fill a None grad if enabled.

  NOTE(review): `self` is a free variable here, presumably captured from an
  enclosing method; confirm against the caller.
  """
  grad = g
  if grad is not None and v is not None:
    grad = common_layers.cast_like(grad, v)
  replace_missing = self._zero_grads and grad is None
  if replace_missing:
    grad = tf.zeros_like(v)
  return grad, v
def _resource_apply_dense(self, grad, handle):
  """Build the ops that apply one dense optimizer update to `handle`.

  The second-moment estimate is either factored into row/column running
  means ("vr"/"vc" slots) or kept in full (slot "v"), as decided by
  self._should_use_factored_second_moment_estimate. The update is optionally
  RMS-clipped, scaled, smoothed with momentum (slot "m") when self._beta1
  is set, and then subtracted from the variable.

  Args:
    grad: gradient Tensor for the variable (converted to float32).
    handle: the variable being updated.

  Returns:
    A tf.group of the variable assignment and all slot assignments.
  """
  var = handle
  grad = tf.to_float(grad)
  grad_squared = tf.square(grad) + self._epsilon1
  grad_squared_mean = tf.reduce_mean(grad_squared)
  decay_rate = self._decay_rate
  update_scale = self._learning_rate
  old_val = var
  # bfloat16 variables are stored encoded; decode to float32 for the math.
  if var.dtype.base_dtype == tf.bfloat16:
    old_val = tf.to_float(self._parameter_encoding.decode(old_val))
  if self._multiply_by_parameter_scale:
    update_scale *= tf.to_float(self._parameter_scale(old_val))
  # HACK: Make things dependent on grad.
  # This confounds the XLA rewriter and keeps it from fusing computations
  # across different variables.  This fusion is bad for HBM usage, since
  # it causes the gradients to persist in memory.
  decay_rate += grad_squared_mean * 1e-30
  update_scale += grad_squared_mean * 1e-30
  # END HACK
  mixing_rate = 1.0 - decay_rate
  shape = var.get_shape().as_list()
  updates = []
  if self._should_use_factored_second_moment_estimate(shape):
    # Factored estimate: running means over rows (last axis) and columns
    # (second-to-last axis).
    grad_squared_row_mean = tf.reduce_mean(grad_squared, -1)
    grad_squared_col_mean = tf.reduce_mean(grad_squared, -2)
    vr = self.get_slot(var, "vr")
    new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
    vc = self.get_slot(var, "vc")
    new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
    vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
    vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
    updates = [vr_update, vc_update]
    long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
    r_factor = tf.rsqrt(new_vr / long_term_mean)
    c_factor = tf.rsqrt(new_vc)
    x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2)
  else:
    # Unfactored estimate: full running mean of squared gradients.
    v = self.get_slot(var, "v")
    new_v = decay_rate * v + mixing_rate * grad_squared
    v_update = tf.assign(v, new_v, use_locking=self._use_locking)
    updates = [v_update]
    x = grad * tf.rsqrt(new_v)
  if self._clipping_threshold is not None:
    # Scale the update down so its RMS does not exceed the threshold.
    clipping_denom = tf.maximum(1.0, reduce_rms(x) / self._clipping_threshold)
    x /= clipping_denom
  subtrahend = update_scale * x
  if self._beta1:
    # Momentum: the stored first moment replaces the raw update.
    m = self.get_slot(var, "m")
    new_m = self._beta1 * tf.to_float(m) + (1.0 - self._beta1) * subtrahend
    subtrahend = new_m
    new_m = common_layers.cast_like(new_m, var)
    updates.append(tf.assign(m, new_m, use_locking=self._use_locking))
  new_val = tf.to_float(old_val) - subtrahend
  if var.dtype.base_dtype == tf.bfloat16:
    new_val = self._parameter_encoding.encode(
        new_val, self._quantization_noise)
  if self._simulated_quantize_bits:
    new_val = quantization.simulated_quantize(
        var - subtrahend, self._simulated_quantize_bits,
        self._quantization_noise)
  # The update is computed in float32. Cast to the variable's dtype so
  # tf.assign also works for non-float32 variables (e.g. float16). This is
  # a no-op for float32 variables.
  new_val = tf.cast(new_val, var.dtype)
  var_update = tf.assign(var, new_val, use_locking=self._use_locking)
  updates = [var_update] + updates
  return tf.group(*updates)
def cast_grad(g, v):
  """Return the (gradient, variable) pair with the gradient cast like `v`.

  If either element is None, the pair is returned unchanged.
  """
  if g is None or v is None:
    return (g, v)
  return (common_layers.cast_like(g, v), v)
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in
      encoder-decoder attention
  """
  static_input_shape = inputs.shape.as_list()
  encoder_input = inputs
  unidirectional = (hasattr(hparams, "unidirectional_encoder") and
                    hparams.unidirectional_encoder)

  if features and "inputs_segmentation" in features:
    # Packed dataset: keep the examples from seeing each other.
    input_segments = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    target_segments = features["targets_segmentation"]
    if unidirectional:
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      encoder_self_attention_bias = (
          common_attention.attention_bias_same_segment(
              input_segments, input_segments))
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(
            target_segments, input_segments))
  else:
    # Usual case - not a packed dataset: mask out padding positions.
    inputs_position = None
    padding_bias = common_attention.attention_bias_ignore_padding(
        common_attention.embedding_to_padding(encoder_input))
    if unidirectional:
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      encoder_self_attention_bias = padding_bias
    encoder_decoder_attention_bias = padding_bias

  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])

  use_target_space = hparams.get("use_target_space_embedding", True)
  if target_space is not None and use_target_space:
    # Append target_space_id embedding to inputs.
    target_space_embedding = common_layers.embedding(
        target_space, 32, static_input_shape[-1],
        name="target_space_embedding",
        dtype=hparams.get("activation_dtype", "float32"))
    encoder_input += tf.reshape(target_space_embedding, [1, 1, -1])

  if hparams.pos == "timing":
    if inputs_position is None:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    else:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
  elif hparams.pos == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        inputs_position)

  encoder_self_attention_bias = common_layers.cast_like(
      encoder_self_attention_bias, encoder_input)
  encoder_decoder_attention_bias = common_layers.cast_like(
      encoder_decoder_attention_bias, encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
def cast_grad(g, v):
  """Cast gradient `g` like variable `v` when both are present.

  Returns:
    A (gradient, variable) tuple; the gradient is unchanged if either
    element is None.
  """
  both_present = g is not None and v is not None
  grad = common_layers.cast_like(g, v) if both_present else g
  return grad, v
def dot_product_area_attention(q,
                               k,
                               v,
                               bias,
                               dropout_rate=0.0,
                               image_shapes=None,
                               name=None,
                               attention_image_summary=None,
                               save_weights_to=None,
                               dropout_broadcast_dims=None,
                               max_area_width=1,
                               max_area_height=1,
                               memory_height=1,
                               area_key_mode="mean",
                               area_value_mode="sum",
                               top_k_areas=0,
                               area_temperature=1.0,
                               training=True):
  """Dot-product area attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q. depth_v may differ from depth_k.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    attention_image_summary: the callback for making image summary of
      attention.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    memory_height: the height of the memory.
    area_key_mode: the mode for computing area keys, which can be "mean",
      "concat", "sum", "sample_concat", and "sample_sum".
    area_value_mode: the mode for computing area values, which can be
      "mean", "max", or "sum".
    top_k_areas: Use the top key areas for attention.
    area_temperature: the temperature for attention softmax.
    training: indicating if it is in the training mode.

  Returns:
    Tensor with shape [..., length_q, depth_v].

  Raises:
    ValueError: if area_value_mode is not "mean", "max" or "sum".
  """
  tf.logging.info(
      "dot_product_area_attention: "
      "area_h=%d, area_w=%d, mem_h=%d, "
      "area_key_mode=%s, area_value_mode=%s, "
      "area_temperature=%f",
      max_area_height, max_area_width, memory_height,
      area_key_mode, area_value_mode,
      area_temperature)
  with tf.variable_scope(name, default_name="dot_product_area_attention",
                         values=[q, k, v]) as scope:
    mem_shape = common_layers.shape_list(k)
    batch_size = mem_shape[0]
    head_size = mem_shape[1]
    length = mem_shape[2]
    depth = mem_shape[3]
    # Values use their own depth: reshaping v with the key depth breaks
    # whenever depth_v != depth_k.
    depth_v = common_layers.shape_list(v)[3]
    k_area = compute_area_key(
        tf.reshape(k, [-1, length, depth]),
        max_area_width=max_area_width,
        max_area_height=max_area_height,
        height=memory_height,
        mode=area_key_mode,
        training=training)
    flat_v = tf.reshape(v, [-1, length, depth_v])
    if area_value_mode == "mean":
      v_area, _, _, _, _ = compute_area_features(
          flat_v, max_area_width=max_area_width,
          max_area_height=max_area_height, height=memory_height)
    elif area_value_mode == "max":
      v_area, _, _ = basic_pool(flat_v, max_area_width=max_area_width,
                                max_area_height=max_area_height,
                                height=memory_height,
                                fn=tf.reduce_max)
    elif area_value_mode == "sum":
      _, _, v_area, _, _ = compute_area_features(
          flat_v, max_area_width=max_area_width,
          max_area_height=max_area_height, height=memory_height)
    else:
      raise ValueError("Unsupported area value mode=%s" % area_value_mode)
    k = tf.reshape(k_area, [batch_size, head_size, -1, depth])
    v = tf.reshape(v_area, [batch_size, head_size, -1, depth_v])
    logits = tf.matmul(q, k, transpose_b=True)  # [..., length_q, num_areas]
    if bias is not None:
      bias = common_layers.cast_like(bias, logits)
      with tf.name_scope("compute_area_att_bias", values=[bias]):
        bias_shape = common_layers.shape_list(bias)
        mem_length = bias_shape[-1]
        # 1.0 where the incoming bias is below -1 (a masked position).
        bias_values = tf.reshape(
            tf.to_float(tf.less(bias, -1)), [-1, mem_length, 1])
        # Presumably the per-area sum of the mask (the same output the
        # "sum" value mode uses). A nonzero sum masks the whole area.
        _, _, padding_sum, _, _ = compute_area_features(
            bias_values, max_area_width=max_area_width,
            max_area_height=max_area_height, height=memory_height)
        bias = tf.where(
            tf.cast(tf.to_int32(padding_sum), tf.bool),
            tf.fill(tf.shape(padding_sum), -np.inf),
            tf.zeros_like(padding_sum, dtype=tf.float32))
        bias = tf.reshape(bias,
                          [bias_shape[0], bias_shape[1],
                           bias_shape[2], -1])
        # The area bias is rebuilt in float32; match the logits' dtype so
        # the addition also works for float16/bfloat16 logits.
        bias = common_layers.cast_like(bias, logits)
      logits += bias
    logits = logits / area_temperature
    weights = tf.nn.softmax(logits, name="attention_weights")
    if top_k_areas > 0:
      tf.logging.info("area_attention top_k_areas=%d", top_k_areas)
      top_k = tf.minimum(common_layers.shape_list(weights)[-1], top_k_areas)
      top_weights, _ = tf.nn.top_k(weights, k=top_k)
      min_values = tf.reduce_min(top_weights, -1, keepdims=True)
      # Keep only the top-k weights per query, then renormalize.
      weights = tf.where(tf.greater_equal(weights, min_values),
                         weights, tf.zeros_like(weights))
      weights = tf.div(weights, tf.reduce_sum(weights, -1, keepdims=True))
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = logits
    # Drop out attention links for each head.
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and attention_image_summary:
      attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)
def transformer_prepare_encoder(inputs, target_space, hparams, features=None,
                                type_ids=None, num_types=None,
                                reuse_target_embedding=tf.AUTO_REUSE):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.
    type_ids: optional, an int64 Tensor of shape [batch, length] that allows
      for adding type embeddings, similar to positional embeddings.
    num_types: optional, an int that decides the number of types in type_ids.
    reuse_target_embedding: option to reuse variable name in the case that
      symbol modalities are reused between inputs/targets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in
      encoder-decoder attention

  Raises:
    ValueError: if type_ids is given without num_types.
  """
  input_static_shape = inputs.shape.as_list()
  encoder_input = inputs
  causal = (hasattr(hparams, "unidirectional_encoder") and
            hparams.unidirectional_encoder)
  is_packed = features and "inputs_segmentation" in features

  if not is_packed:
    # Usual case - not a packed dataset: attention ignores padding.
    inputs_position = None
    padding_bias = common_attention.attention_bias_ignore_padding(
        common_attention.embedding_to_padding(encoder_input))
    if causal:
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      encoder_self_attention_bias = padding_bias
    encoder_decoder_attention_bias = padding_bias
  else:
    # Packed dataset: keep the examples from seeing each other.
    source_segments = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    target_segments = features["targets_segmentation"]
    if causal:
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      encoder_self_attention_bias = (
          common_attention.attention_bias_same_segment(
              source_segments, source_segments))
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(
            target_segments, source_segments))

  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])

  if (target_space is not None and
      hparams.get("use_target_space_embedding", True)):
    # Append target_space_id embedding to inputs.
    space_embedding = common_layers.embedding(
        target_space, 32, input_static_shape[-1],
        name="target_space_embedding",
        dtype=hparams.get("activation_dtype", "float32"),
        reuse=reuse_target_embedding)
    encoder_input += tf.reshape(space_embedding, [1, 1, -1])

  position_mode = hparams.pos
  if position_mode == "timing":
    if inputs_position is None:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    else:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
  elif position_mode == "timing_from_features":
    encoder_input = common_attention.add_timing_signals_from_features(
        encoder_input, features, hparams.position_features)
  elif position_mode == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        inputs_position)

  # Type embeddings reuse the positional-embedding machinery.
  if type_ids is not None:
    if not num_types:
      raise ValueError("Need to set num_types as well.")
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, num_types, "inputs_type_embedding", type_ids)

  encoder_self_attention_bias = common_layers.cast_like(
      encoder_self_attention_bias, encoder_input)
  encoder_decoder_attention_bias = common_layers.cast_like(
      encoder_decoder_attention_bias, encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
def dot_product_attention_mtsa(
    q,
    k,
    v,
    bias,
    dropout_rate=0.0,
    image_shapes=None,
    name=None,
    make_image_summary=True,
    save_weights_to=None,
    dropout_broadcast_dims=None,
    use_k_mtsa=True,
    afn_extra='none',
    afn_dot='exp',
    afn_multi='exp',
    bias_start=0.,
    bi_direction=False,
):
  """Dot-product attention combined with a multi-dimensional (MTSA) score.

  The output divides two accumulations:
    * exp-style token2token scores (`afn_dot` applied to q·k + bias), and
    * per-feature source2token scores (`afn_multi` applied to a dense
      projection of k or v).
  The numerator is the accumulation of v weighted by both scores; the
  denominator is the accumulation of the two scores alone.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    use_k_mtsa: if True, the source2token scores are projected from k,
      otherwise from v.
    afn_extra: name of the activation for an optional extra dense layer in
      the source2token branch. It is resolved via afn_name2fn; that layer is
      skipped when the resolved value is None (presumably for 'none' —
      confirm in afn_name2fn).
    afn_dot: activation name applied to the token2token logits.
    afn_multi: activation name applied to the source2token logits. A name
      starting with 'scaled' also scales those logits by 1/sqrt(depth_v).
    bias_start: initial bias passed to the source2token dense layer(s).
    bi_direction: if True, the first half of the heads use a lower-triangular
      (forward) mask and the second half an upper-triangular (backward)
      mask. Requires a static, even head count in v's dimension 1.

  Returns:
    Tensor with shape [..., length_q, depth_v].

  Raises:
    ValueError: if bi_direction is set and the number of heads is not
      statically known or not even.
  """
  tf.logging.info("dot_product_attention_mtsa")
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
    dim_v = v.get_shape().as_list()[-1]
    # The scale is chosen from the *name* of afn_multi, so compute it before
    # the activation names are resolved to functions below.
    multi_logits_scale_factor = 1. / math.sqrt(
        dim_v) if afn_multi.startswith('scaled') else 1.
    afn_extra, afn_dot, afn_multi = afn_name2fn(afn_extra), afn_name2fn(
        afn_dot), afn_name2fn(afn_multi)

    # token2token self attention
    dot_logits = tf.matmul(q, k, transpose_b=True)  # bs,hd,ql,vl
    if bias is not None:
      bias = common_layers.cast_like(bias, dot_logits)  # 1/bs,1,ql/1,vl
      dot_logits += bias
    e_dot_logits = afn_dot(dot_logits)  # bs,hd,ql,vl
    if bi_direction:
      head_num = v.get_shape().as_list()[1]
      # Explicit checks rather than `assert`, which is stripped under -O.
      if head_num is None:
        raise ValueError(
            "bi_direction requires a statically known number of heads.")
      if head_num % 2 != 0:
        raise ValueError(
            "bi_direction requires an even number of heads, got %d." %
            head_num)
      ql, vl = tf.shape(q)[-2], tf.shape(v)[-2]
      ones_mat = tf.ones([ql, vl], tf.float32)
      mul_mask_fw = tf.matrix_band_part(ones_mat, -1, 0)  # Lower triangular.
      mul_mask_bw = tf.matrix_band_part(ones_mat, 0, -1)  # Upper triangular.
      mul_mask_fw_tile = tf.tile(tf.expand_dims(mul_mask_fw, 0),
                                 [head_num // 2, 1, 1])
      mul_mask_bw_tile = tf.tile(tf.expand_dims(mul_mask_bw, 0),
                                 [head_num // 2, 1, 1])
      # First half of the heads looks forward, second half backward.
      mul_mask = tf.expand_dims(
          tf.concat([mul_mask_fw_tile, mul_mask_bw_tile], axis=0), axis=0)
      e_dot_logits *= mul_mask

    # source2token self-attention
    multi_logits = multi_head_dense_layer(
        k if use_k_mtsa else v, dim_v, True,
        bias_start if afn_extra is None else 0., 'multi_logits1')
    if afn_extra is not None:  # use one extra layer for multi-dim
      multi_logits = multi_head_dense_layer(
          afn_extra(multi_logits), dim_v, True, bias_start, 'multi_logits2')
    e_multi_logits = afn_multi(
        multi_logits * multi_logits_scale_factor)  # bs,hd,vl,vd

    # mtsa
    accum_z_deno = tf.matmul(e_dot_logits, e_multi_logits)  # bs,hd,ql,vd
    # Replace non-positive denominators with 1 to avoid NaN and Inf.
    accum_z_deno = tf.where(
        tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
        accum_z_deno,
        tf.ones_like(accum_z_deno))

    # Attention dropout: each factor keeps with probability
    # sqrt(1 - dropout_rate).
    e_dot_logits = common_layers.dropout_with_broadcast_dims(
        e_dot_logits, math.sqrt(1. - dropout_rate),
        broadcast_dims=dropout_broadcast_dims)
    e_multi_logits = common_layers.dropout_with_broadcast_dims(
        e_multi_logits, math.sqrt(1. - dropout_rate),
        broadcast_dims=dropout_broadcast_dims)
    rep_mul_score = v * e_multi_logits  # bs,hd,vl,vd
    accum_rep_mul_score = tf.matmul(e_dot_logits, rep_mul_score)  # bs,hd,ql,vd
    # Calculate the final attention results.
    attn_res = accum_rep_mul_score / accum_z_deno

    # Normalized token2token weights, used only for visualization.
    weights = e_dot_logits / (tf.reduce_sum(
        e_dot_logits, axis=-1, keepdims=True,
        name="attention_weights") + 0.00001)
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = dot_logits
    if common_layers.should_generate_summaries() and make_image_summary:
      common_attention.attention_image_summary(weights, image_shapes)
    return attn_res