def double_linear_logits(args, size, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0,
                         is_train=None):
    with tf.variable_scope(scope or "Double_Linear_Logits"):
        first = tf.tanh(linear(args, size, bias, bias_start=bias_start, scope='first',
                               wd=wd, input_keep_prob=input_keep_prob, is_train=is_train))
        second = linear(first, 1, bias, bias_start=bias_start, squeeze=True, scope='second',
                        wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)
        if mask is not None:
            second = exp_mask(second, mask)
        return second
def highway_layer(arg, bias, bias_start=0.0, scope=None, wd=0.0, input_keep_prob=1.0, is_train=None):
    with tf.variable_scope(scope or "highway_layer"):
        d = arg.get_shape()[-1]  # embedding dim
        trans = linear([arg], d, bias, bias_start=bias_start, scope='trans', wd=wd,
                       input_keep_prob=input_keep_prob, is_train=is_train)
        trans = tf.nn.relu(trans)
        gate = linear([arg], d, bias, bias_start=bias_start, scope='gate', wd=wd,
                      input_keep_prob=input_keep_prob, is_train=is_train)
        gate = tf.nn.sigmoid(gate)
        # highway connection: the gate interpolates between the transformed and original input
        out = gate * trans + (1 - gate) * arg
        return out
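# Usage sketch (illustrative only): stacking two highway layers over word embeddings.
# `emb` with shape [batch, seq_len, d] and the boolean `is_train` placeholder are
# assumptions of this example, not part of the module.
#
#     h = highway_layer(emb, True, bias_start=0.0, scope='highway_0',
#                       wd=0.0, input_keep_prob=0.8, is_train=is_train)
#     h = highway_layer(h, True, bias_start=0.0, scope='highway_1',
#                       wd=0.0, input_keep_prob=0.8, is_train=is_train)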
def __init__(self, cfg, num_modules):
    super().__init__()
    self.cfg = cfg
    self.num_modules = num_modules

    control_dim = cfg.MODEL.KB_DIM
    if cfg.MODEL.CTRL.USE_WORD_EMBED:
        control_dim = cfg.MODEL.EMBED_DIM

    dim = cfg.MODEL.LSTM_DIM
    self.shared_control_proj = linear(dim, dim)
    self.position_aware = nn.ModuleList()
    for i in range(cfg.MODEL.T_CTRL):
        self.position_aware.append(linear(dim, dim))

    self.control_question = linear(dim + control_dim, dim)
    self.attn = linear(dim, 1)

    if self.cfg.MODEL.CTRL.LINEAR_MODULE_WEIGHTS:
        self.module_fc = nn.Linear(dim, num_modules, bias=False)
    else:
        self.module_fc = nn.Sequential(
            nn.Linear(dim, cfg.MODEL.LSTM_DIM),
            nn.ELU(),
            nn.Linear(cfg.MODEL.LSTM_DIM, num_modules))
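# Instantiation sketch (illustrative only): the constructor expects a config object
# exposing the fields read above. The yacs-style CfgNode, the concrete values, and the
# enclosing class name (written here as `Controller`) are assumptions of this example.
#
#     from yacs.config import CfgNode as CN
#     cfg = CN()
#     cfg.MODEL = CN()
#     cfg.MODEL.KB_DIM = 512
#     cfg.MODEL.EMBED_DIM = 300
#     cfg.MODEL.LSTM_DIM = 512
#     cfg.MODEL.T_CTRL = 12
#     cfg.MODEL.CTRL = CN()
#     cfg.MODEL.CTRL.USE_WORD_EMBED = False
#     cfg.MODEL.CTRL.LINEAR_MODULE_WEIGHTS = False
#     controller = Controller(cfg, num_modules=9)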
def linear_logits(args, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0, is_train=None):
    with tf.variable_scope(scope or "Linear_Logits"):
        logits = linear(args, 1, bias, bias_start=bias_start, squeeze=True, scope='first',
                        wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)
        if mask is not None:
            logits = exp_mask(logits, mask)
        return logits
def get_logits(args, size, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0,
               is_train=None, func=None):
    if func is None:
        func = "linear"
    if func == 'sum':
        return sum_logits(args, mask=mask, name=scope)
    elif func == 'linear':
        return linear_logits(args, bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd,
                             input_keep_prob=input_keep_prob, is_train=is_train)
    elif func == 'double':
        return double_linear_logits(args, size, bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd,
                                    input_keep_prob=input_keep_prob, is_train=is_train)
    elif func == 'dot':
        assert len(args) == 2
        arg = args[0] * args[1]
        return sum_logits([arg], mask=mask, name=scope)
    elif func == 'mul_linear':
        assert len(args) == 2
        arg = args[0] * args[1]
        return linear_logits([arg], bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd,
                             input_keep_prob=input_keep_prob, is_train=is_train)
    elif func == 'proj':
        assert len(args) == 2
        d = args[1].get_shape()[-1]
        proj = linear([args[0]], d, False, bias_start=bias_start, scope=scope, wd=wd,
                      input_keep_prob=input_keep_prob, is_train=is_train)
        return sum_logits([proj * args[1]], mask=mask)
    elif func == 'tri_linear':
        assert len(args) == 2
        new_arg = args[0] * args[1]
        return linear_logits([args[0], args[1], new_arg], bias, bias_start=bias_start, scope=scope, mask=mask,
                             wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)
    else:
        raise Exception("Unknown logit function: {}".format(func))
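# Usage sketch (illustrative only): BiDAF-style similarity logits between tiled context
# and query representations. `h_tiled` and `u_tiled` with shape [batch, ctx_len, q_len, d]
# and the matching boolean mask `hu_mask` are assumed to be built by the caller; `size`
# is unused by the 'tri_linear' branch, so None is passed.
#
#     sim = get_logits([h_tiled, u_tiled], None, True, scope='sim_logits',
#                      mask=hu_mask, wd=0.0, input_keep_prob=0.8,
#                      is_train=is_train, func='tri_linear')   # [batch, ctx_len, q_len]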
def directional_attention_with_dense(rep_tensor, rep_mask, direction=None, scope=None,
                                     keep_prob=1., is_train=None, wd=0., activation='elu',
                                     tensor_dict=None, name=None):
    def scaled_tanh(x, scale=5.):
        return scale * tf.nn.tanh(1. / scale * x)

    bs, sl, vec = tf.shape(rep_tensor)[0], tf.shape(rep_tensor)[1], tf.shape(rep_tensor)[2]
    ivec = rep_tensor.get_shape()[2]
    with tf.variable_scope(scope or 'directional_attention_%s' % (direction or 'diag')):
        # mask generation
        sl_indices = tf.range(sl, dtype=tf.int32)
        sl_col, sl_row = tf.meshgrid(sl_indices, sl_indices)
        if direction is None:
            # attend to every position except the token itself
            direct_mask = tf.cast(tf.diag(- tf.ones([sl], tf.int32)) + 1, tf.bool)
        else:
            if direction == 'forward':
                direct_mask = tf.greater(sl_row, sl_col)
            else:
                direct_mask = tf.greater(sl_col, sl_row)
        direct_mask_tile = tf.tile(tf.expand_dims(direct_mask, 0), [bs, 1, 1])  # bs,sl,sl
        rep_mask_tile = tf.tile(tf.expand_dims(rep_mask, 1), [1, sl, 1])  # bs,sl,sl
        attn_mask = tf.logical_and(direct_mask_tile, rep_mask_tile)  # bs,sl,sl

        # non-linear
        rep_map = bn_dense_layer(rep_tensor, ivec, True, 0., 'bn_dense_map', activation,
                                 False, wd, keep_prob, is_train)
        rep_map_tile = tf.tile(tf.expand_dims(rep_map, 1), [1, sl, 1, 1])  # bs,sl,sl,vec
        rep_map_dp = dropout(rep_map, keep_prob, is_train)

        # attention
        with tf.variable_scope('attention'):  # bs,sl,sl,vec
            f_bias = tf.get_variable('f_bias', [ivec], tf.float32, tf.constant_initializer(0.))
            dependent = linear(rep_map_dp, ivec, False, scope='linear_dependent')  # bs,sl,vec
            dependent_etd = tf.expand_dims(dependent, 1)  # bs,1,sl,vec
            head = linear(rep_map_dp, ivec, False, scope='linear_head')  # bs,sl,vec
            head_etd = tf.expand_dims(head, 2)  # bs,sl,1,vec

            logits = scaled_tanh(dependent_etd + head_etd + f_bias, 5.0)  # bs,sl,sl,vec

            logits_masked = exp_mask_for_high_rank(logits, attn_mask)
            attn_score = tf.nn.softmax(logits_masked, 2)  # bs,sl,sl,vec
            attn_score = mask_for_high_rank(attn_score, attn_mask)

            attn_result = tf.reduce_sum(attn_score * rep_map_tile, 2)  # bs,sl,vec

        with tf.variable_scope('output'):
            o_bias = tf.get_variable('o_bias', [ivec], tf.float32, tf.constant_initializer(0.))
            # input gate
            fusion_gate = tf.nn.sigmoid(
                linear(rep_map, ivec, True, 0., 'linear_fusion_i', False, wd, keep_prob, is_train) +
                linear(attn_result, ivec, True, 0., 'linear_fusion_a', False, wd, keep_prob, is_train) +
                o_bias)
            output = fusion_gate * rep_map + (1 - fusion_gate) * attn_result
            output = mask_for_high_rank(output, rep_mask)

        # save attn
        if tensor_dict is not None and name is not None:
            tensor_dict[name + '_dependent'] = dependent
            tensor_dict[name + '_head'] = head
            tensor_dict[name] = attn_score
            tensor_dict[name + '_gate'] = fusion_gate
        return output
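# Usage sketch (illustrative only): the directional self-attention block is typically run
# once per direction and the results concatenated. `rep` with shape [batch, seq_len, d]
# and the boolean [batch, seq_len] mask `rep_mask` are assumptions of this example.
#
#     fw = directional_attention_with_dense(rep, rep_mask, 'forward', 'dir_attn_fw',
#                                           keep_prob=0.75, is_train=is_train, wd=1e-5)
#     bw = directional_attention_with_dense(rep, rep_mask, 'backward', 'dir_attn_bw',
#                                           keep_prob=0.75, is_train=is_train, wd=1e-5)
#     seq_rep = tf.concat([fw, bw], -1)   # [batch, seq_len, 2*d]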
def multihead_attention(query, memory, bias, key_size, value_size, output_size, num_heads,
                        keep_prob=None, data_format="NHWC", attention_function="dot_product",
                        dtype=None, scope=None):
    """ Multihead scaled-dot-product attention with input/output transformations.

    Args:
        query: a Tensor with shape [batch, length_q, channels] if data_format is
            `NHWC`, [batch, channels, length_q] if data_format is `NCHW`
        memory: a Tensor with shape [batch, length_m, channels] if data_format is
            `NHWC`, [batch, channels, length_m] if data_format is `NCHW`
        bias: bias Tensor (see attention_bias())
        key_size: an integer
        value_size: an integer
        output_size: an integer
        num_heads: an integer dividing key_size and value_size
        keep_prob: a floating point number
        data_format: "NHWC" or "NCHW"
        attention_function: "dot_product" or "additive"
        dtype: an optional instance of tf.DType
        scope: an optional string

    Returns:
        A Tensor.
    """
    if key_size % num_heads != 0:
        raise ValueError("Key size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (key_size, num_heads))

    if value_size % num_heads != 0:
        raise ValueError("Value size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (value_size, num_heads))

    with tf.variable_scope(scope, default_name="multihead_attention",
                           values=[query, memory], dtype=dtype):
        axis = 2  # channel axis for [batch, length, channels] inputs

        if memory is None:
            # self attention: project the query once into concatenated q, k, v
            size = key_size * 2 + value_size
            combined = linear(query, size, True, True, data_format=data_format,
                              scope="qkv_transform")
            q, k, v = tf.split(combined, [key_size, key_size, value_size], axis=axis)
        else:
            q = linear(query, key_size, True, data_format=data_format,
                       scope="q_transform")
            combined = linear(memory, key_size + value_size, True,
                              data_format=data_format, scope="kv_transform")
            k, v = tf.split(combined, [key_size, value_size], axis=axis)

        # split heads
        q = _split_heads(q, num_heads)
        k = _split_heads(k, num_heads)
        v = _split_heads(v, num_heads)

        # scale query
        if attention_function == "dot_product":
            key_depth_per_head = key_size // num_heads
            q *= key_depth_per_head ** -0.5

        # attention
        x = dot_product_attention(q, k, v, bias, keep_prob)

        # combine heads
        x = _combine_heads(x)
        x = linear(x, output_size, True, data_format=data_format,
                   scope="output_transform")

        return x
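# Usage sketch (illustrative only): encoder self-attention with `memory=None`, assuming
# `x` has shape [batch, length, hidden] and `enc_attn_bias` was produced by an
# attention_bias()-style helper that masks padding positions.
#
#     y = multihead_attention(x, None, enc_attn_bias, key_size=512, value_size=512,
#                             output_size=512, num_heads=8, keep_prob=0.9,
#                             scope="encoder_self_attention")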