def __call__(self, inputs, state, scope=None):
    """
    Attend over the question vectors packed into `inputs`, then step the wrapped cell.

    Each row of `inputs` is laid out as: the step input (d values), the
    question mask (JQ values), and the flattened question vectors (JQ * d values),
    where d is `self._input_size` and JQ is `self._q_len`.

    :param inputs: [N, d + JQ + JQ * d]
    :param state: two-element state; only its second element (the previous
        hidden state, [N, d]) feeds the attention. The whole `state` is handed
        to the wrapped cell unchanged.
    :param scope: variable scope name; falls back to the class name
    :return: result of `self._cell` applied to the [N, 2d] concatenation of the
        step input and the attended question summary
    """
    with tf.variable_scope(scope or self.__class__.__name__):
        d = self._input_size
        jq = self._q_len
        _, hidden = state

        # Unpack the three segments of every input row.
        step_in = tf.slice(inputs, [0, 0], [-1, d])
        question_mask = tf.slice(inputs, [0, d], [-1, jq])  # [N, JQ]
        questions = tf.slice(inputs, [0, d + jq], [-1, -1])
        questions = tf.reshape(questions, [-1, jq, d])  # [N, JQ, d]

        # Repeat the step input and hidden state once per question position.
        per_position = [1, jq, 1]
        step_in_rep = tf.tile(tf.expand_dims(step_in, 1), per_position)  # [N, JQ, d]
        hidden_rep = tf.tile(tf.expand_dims(hidden, 1), per_position)  # [N, JQ, d]

        # Score each question position and normalise into attention weights.
        # NOTE(review): exp_mask presumably suppresses masked-out positions
        # before the softmax - confirm against its definition.
        features = linear([questions, step_in_rep, hidden_rep], d, True, scope='f')
        features = tf.tanh(features)  # [N, JQ, d]
        scores = linear(features, 1, True, squeeze=True, scope='a')
        weights = tf.nn.softmax(exp_mask(scores, question_mask))  # [N, JQ]

        # Weighted sum of question vectors, joined with the raw step input.
        summary = tf.reduce_sum(questions * tf.expand_dims(weights, -1), 1)
        cell_input = tf.concat([step_in, summary], 1)  # [N, 2d]
        return self._cell(cell_input, state)
def linear_logits(args, bias, bias_start=0.0, scope=None, mask=None, wd=0.0,
                  input_keep_prob=1.0, is_train=None):
    """
    Compute logits with a single `linear` layer of output size 1.

    :param args: tensor or list of tensors forwarded to `linear`
        (e.g. [N, M, JX, JQ, 6d] per the original note)
    :param bias: forwarded to `linear` as its bias argument
    :param bias_start: forwarded to `linear`
    :param scope: variable scope name; defaults to "Linear_Logits"
    :param mask: when given, the logits are passed through `exp_mask` with it
    :param wd: forwarded to `linear`
    :param input_keep_prob: forwarded to `linear`
    :param is_train: forwarded to `linear`
    :return: the (optionally masked) logits
    """
    with tf.variable_scope(scope or "Linear_Logits"):
        # NOTE(review): squeeze=True presumably drops the size-1 output axis -
        # confirm against `linear`.
        scores = linear(
            args, 1, bias,
            bias_start=bias_start,
            squeeze=True,
            scope='first',
            wd=wd,
            input_keep_prob=input_keep_prob,
            is_train=is_train,
        )
        return scores if mask is None else exp_mask(scores, mask)
def double_linear_logits(args, size, bias, bias_start=0.0, scope=None, mask=None, wd=0.0,
                         input_keep_prob=1.0, is_train=None):
    """
    Compute logits with two stacked `linear` layers and a tanh in between.

    The first layer ('first') maps `args` to `size` units and is passed through
    tanh; the second layer ('second') maps that to a single output.

    :param args: tensor or list of tensors forwarded to the first `linear`
    :param size: output size of the first layer
    :param bias: forwarded to both `linear` calls as the bias argument
    :param bias_start: forwarded to both `linear` calls
    :param scope: variable scope name; defaults to "Double_Linear_Logits"
    :param mask: when given, the logits are passed through `exp_mask` with it
    :param wd: forwarded to both `linear` calls
    :param input_keep_prob: forwarded to both `linear` calls
    :param is_train: forwarded to both `linear` calls
    :return: the (optionally masked) logits
    """
    with tf.variable_scope(scope or "Double_Linear_Logits"):
        # Keyword arguments that both layers receive unchanged.
        common = dict(
            bias_start=bias_start,
            wd=wd,
            input_keep_prob=input_keep_prob,
            is_train=is_train,
        )
        hidden = tf.tanh(linear(args, size, bias, scope='first', **common))
        logits = linear(hidden, 1, bias, squeeze=True, scope='second', **common)
        if mask is None:
            return logits
        return exp_mask(logits, mask)
def softmax(logits, mask=None, scope=None):
    """
    Apply `tf.nn.softmax` to `logits` after flattening, then restore the shape.

    :param logits: tensor of logits
    :param mask: when given, `logits` is first passed through `exp_mask` with it
    :param scope: name scope; defaults to "Softmax"
    :return: output of `reconstruct` applied to the flattened softmax result,
        using the (masked) logits as the reference tensor
    """
    with tf.name_scope(scope or "Softmax"):
        masked = logits if mask is None else exp_mask(logits, mask)
        # NOTE(review): flatten(..., 1) / reconstruct(..., 1) presumably collapse
        # and then restore every axis except the last - confirm against helpers.
        flat_probs = tf.nn.softmax(flatten(masked, 1))
        return reconstruct(flat_probs, masked, 1)
def sum_logits(args, mask=None, name=None):
    """
    Build logits by summing each tensor over its last axis and adding the results.

    The axis reduced is `rank - 1`, where rank is taken from the first tensor.

    :param args: a tensor or a non-empty sequence of tensors
    :param mask: when given, the logits are passed through `exp_mask` with it
    :param name: name scope; defaults to "sum_logits"
    :return: the (optionally masked) summed logits
    :raises ValueError: if `args` is None or an empty sequence
    """
    with tf.name_scope(name or "sum_logits"):
        if args is None:
            raise ValueError("`args` must be specified")
        if not nest.is_sequence(args):
            tensors = [args]
        elif args:
            tensors = args
        else:
            raise ValueError("`args` must be specified")

        last_axis = len(tensors[0].get_shape()) - 1
        logits = sum(tf.reduce_sum(tensor, last_axis) for tensor in tensors)
        return logits if mask is None else exp_mask(logits, mask)
def __call__(self, inputs, state, scope=None):
    """
    Combine the B previous states according to a mask, then step the wrapped cell.

    Each row of `inputs` carries the step input (I = `self._input_size` values)
    followed by a B-wide mask row.

    :param inputs: [N*B, I + B]
    :param state: [N*B, d]
    :param scope: variable scope name; falls back to the class name
    :return: [N*B, d]
    """
    with tf.variable_scope(scope or self.__class__.__name__):
        width = self.state_size
        in_size = self._input_size

        step_in = tf.slice(inputs, [0, 0], [-1, in_size])  # [N*B, I]
        flat_mask = tf.slice(inputs, [0, in_size], [-1, -1])  # [N*B, B]
        branches = tf.shape(flat_mask)[1]

        # Regroup the states per example and add an axis to pair against the mask.
        grouped = tf.reshape(state, [-1, branches, width])  # [N, B, d]
        grouped = tf.expand_dims(grouped, 1)  # [N, 1, B, d]

        # Expand the mask to one entry per state unit.
        pair_mask = tf.reshape(flat_mask, [-1, branches, branches])
        pair_mask = tf.expand_dims(pair_mask, -1)
        pair_mask = tf.tile(pair_mask, [1, 1, 1, width])  # [N, B, B, d]

        # NOTE(review): the states are not tiled to [N, B, B, d] here; exp_mask
        # presumably broadcasts them against the mask - confirm.
        merged = self._reduce_func(exp_mask(grouped, pair_mask), 2)  # [N, B, d]
        merged = tf.reshape(merged, [-1, width])  # [N*B, d]
        return self._cell(step_in, merged)