Example #1
 def call(self, inputs):
     if isinstance(inputs, list):
         x, x_mask = inputs
     else:
         x, x_mask = inputs, None
     seq_dim = K.int_shape(x)[-1]
     # Pad the length so the sequence can be reshaped evenly
     seq_len = K.shape(x)[1]
     pad_len = self.rate - seq_len % self.rate
     x = K.temporal_padding(x, (0, pad_len))
     if x_mask is not None:
         x_mask = K.temporal_padding(x_mask, (0, pad_len))
     new_seq_len = K.shape(x)[1]
     # Reshape into rate-strided subsequences
     x = K.reshape(x, (-1, new_seq_len // self.rate, self.rate, seq_dim))
     x = K.permute_dimensions(x, (0, 2, 1, 3))
     x = K.reshape(x, (-1, new_seq_len // self.rate, seq_dim))
     if x_mask is not None:
         x_mask = K.reshape(x_mask,
                            (-1, new_seq_len // self.rate, self.rate, 1))
         x_mask = K.permute_dimensions(x_mask, (0, 2, 1, 3))
         x_mask = K.reshape(x_mask, (-1, new_seq_len // self.rate, 1))
     # Apply attention
     if x_mask is not None:
         x = self.reuse(self.attention, [x, x, x, x_mask, x_mask])
     else:
         x = self.reuse(self.attention, [x, x, x])
     # Restore the original shape
     x = K.reshape(x,
                   (-1, self.rate, new_seq_len // self.rate, self.out_dim))
     x = K.permute_dimensions(x, (0, 2, 1, 3))
     x = K.reshape(x, (-1, new_seq_len, self.out_dim))
     x = x[:, :-pad_len]
     return x
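The reshape / permute / reshape sequence above is what turns one sequence into rate interleaved (dilated) subsequences, so that attention runs within each stride. A minimal NumPy sketch of the same trick, with illustrative shapes:

import numpy as np

batch, seq_len, seq_dim, rate = 1, 6, 1, 2
x = np.arange(seq_len).reshape(batch, seq_len, seq_dim)
# (batch, seq_len, dim) -> (batch * rate, seq_len // rate, dim)
x = x.reshape(batch, seq_len // rate, rate, seq_dim)
x = x.transpose(0, 2, 1, 3)
x = x.reshape(batch * rate, seq_len // rate, seq_dim)
print(x[..., 0])  # [[0 2 4], [1 3 5]]: every rate-th step per subsequence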
Example #2
 def pool(i):
     if pool_size > 1:
         # Left-pad so each position pools only over itself and earlier steps
         o = Lambda(lambda x: K.temporal_padding(x, (pool_size - 1, 0)))(i)
         o = MaxPooling1D(pool_size, strides=1, padding='valid')(o)
     else:
         o = i
     return o
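A sketch of how this helper might be wired into a model; the 64-dim input and the pool_size value are illustrative, not from the original:

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Lambda, MaxPooling1D

pool_size = 3
seq = Input(shape=(None, 64))
# Left padding keeps the output the same length as the input while making
# each window causal: position t pools over steps t-2, t-1, t
padded = Lambda(lambda x: K.temporal_padding(x, (pool_size - 1, 0)))(seq)
pooled = MaxPooling1D(pool_size, strides=1, padding='valid')(padded)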
Example #3
def merge_action_crossed_fn(x):
    inp = x[0]
    is_crossed_orig = x[1]
    # Binarize the crossing signal: values above 0.1 become 1, the rest 0
    is_crossed_orig = K.switch(is_crossed_orig > 0.1,
                               K.maximum(is_crossed_orig, 1),
                               is_crossed_orig * 0)
    action_count = inp.shape[1]
    # Scatter the flag into action slot 0, zero-padding the remaining slots
    is_crossed = K.expand_dims(is_crossed_orig, axis=1)
    is_crossed = K.expand_dims(is_crossed, axis=1)
    is_crossed = K.temporal_padding(is_crossed, (0, action_count - 1))
    is_crossed = K.squeeze(is_crossed, axis=2)
    # Broadcast the flag over all action slots to form a mask
    is_crossed_mask = K.expand_dims(is_crossed_orig, axis=1)
    is_crossed_mask = K.repeat_elements(is_crossed_mask, action_count, axis=1)
    res_crossed = (1 - is_crossed_mask) * inp + is_crossed
    carpisma_timer_orig = x[2]
    carpisma_timer_orig = K.squeeze(carpisma_timer_orig, axis=2)
    # A collision is active when any timer entry is non-zero
    is_carpisma = K.sum(carpisma_timer_orig, axis=1)
    is_carpisma = K.switch(is_carpisma > 0.1, K.maximum(is_carpisma, 1),
                           is_carpisma * 0)
    not_carpisma = 1 - is_carpisma
    print("carpisma timer", carpisma_timer_orig)
    print("is carpisma", is_carpisma.shape)
    print("not carpisma", not_carpisma.shape)
    not_carpisma = K.expand_dims(not_carpisma, axis=1)
    not_carpisma = K.repeat_elements(not_carpisma, action_count, axis=1)
    # Suppress the merged actions while a collision timer is running
    res_crossed = res_crossed * not_carpisma
    res = K.concatenate([res_crossed, carpisma_timer_orig], axis=1)
    return res
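The expand_dims / temporal_padding / squeeze sequence above is a scatter: it writes the binarized flag into slot 0 and zero-fills the rest. A standalone sketch of that idiom (the values and action_count are illustrative):

import numpy as np
from tensorflow.keras import backend as K

action_count = 4
flag = K.constant(np.array([1.0, 0.0]))           # shape (2,)
v = K.expand_dims(K.expand_dims(flag, 1), 1)      # shape (2, 1, 1)
v = K.temporal_padding(v, (0, action_count - 1))  # shape (2, 4, 1)
v = K.squeeze(v, axis=2)                          # shape (2, 4)
print(K.eval(v))  # [[1. 0. 0. 0.], [0. 0. 0. 0.]]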
Example #4
def divisible_temporal_padding(x, n):
    """Right-pad a sequence of 1D vectors so its length is divisible by n.
    """
    r_len = K.shape(x)[1] % n
    p_len = K.switch(r_len > 0, n - r_len, 0)
    return K.temporal_padding(x, (0, p_len))
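Assuming TF 2.x eager execution, a quick illustrative check of the helper:

import numpy as np
from tensorflow.keras import backend as K

x = K.constant(np.ones((2, 5, 3)))
y = divisible_temporal_padding(x, 3)  # 5 % 3 == 2, so one zero step is added
print(K.eval(y).shape)                # (2, 6, 3)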
Example #5
 def call(self, inputs, **kwargs):
     source, target = inputs
     source_shape = K.shape(source)
     target_shape = K.shape(target)
     # Pad the source on both sides so its length matches the target's;
     # when the difference is odd, the extra step goes on the right
     length_diff = target_shape[1] - source_shape[1]
     return K.temporal_padding(source,
                               padding=(length_diff // 2,
                                        length_diff - length_diff // 2))
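Illustrative behaviour (shapes invented for the example): a length-4 source padded to a length-7 target gets one zero step on the left and two on the right. Note the source must not be longer than the target, or the padding would be negative and tf.pad would raise.

import numpy as np
from tensorflow.keras import backend as K

source = K.constant(np.ones((1, 4, 8)))
target = K.constant(np.zeros((1, 7, 8)))
diff = K.shape(target)[1] - K.shape(source)[1]
padded = K.temporal_padding(source, padding=(diff // 2, diff - diff // 2))
print(K.eval(padded).shape)  # (1, 7, 8)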
Example #6
 def pool(i):
     if pool_size > 1:
         # Left-pad so the pooling window is causal (sees only past steps)
         o = Lambda(lambda x: K.temporal_padding(x, (pool_size - 1, 0)))(i)
         o = AveragePooling1D(pool_size, strides=1, padding='valid')(o)
         # The first pool_size windows include zero padding; rescale them so
         # each position averages only over the real elements it has seen
         scale = tf.reshape(pool_size / tf.range(1, pool_size + 1, dtype=o.dtype),
                            (1, pool_size, 1))
         o = tf.concat([scale * o[:, :pool_size, :], o[:, pool_size:, :]], axis=1)
     else:
         o = i
     return o
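A quick numeric check of the rescaling (values illustrative): with pool_size = 3 the window at position 0 is [0, 0, x0], so AveragePooling1D returns x0 / 3 and the scale factor 3/1 restores the true causal mean.

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

pool_size = 3
x = K.constant([[[3.0], [6.0], [9.0], [12.0]]])  # shape (1, 4, 1)
o = K.temporal_padding(x, (pool_size - 1, 0))
o = tf.keras.layers.AveragePooling1D(pool_size, strides=1, padding='valid')(o)
scale = tf.reshape(pool_size / tf.range(1, pool_size + 1, dtype=o.dtype),
                   (1, pool_size, 1))
o = tf.concat([scale * o[:, :pool_size, :], o[:, pool_size:, :]], axis=1)
print(K.eval(o)[0, :, 0])  # [3.  4.5 6.  9. ]: running causal means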
Example #7
 def add_to_timer_fn(x):
     tavuk_crossed = x[1]
     # Collapse any spatial dimensions into a single per-sample flag
     if len(tavuk_crossed.shape) > 2:
         tavuk_crossed = K.sum(tavuk_crossed,
                               axis=[1, 2, 3],
                               keepdims=False)
     tavuk_crossed = K.expand_dims(tavuk_crossed, axis=1)
     tavuk_crossed = K.expand_dims(tavuk_crossed, axis=1)
     # Slots that were not just crossed keep their previous timer values
     not_crossed = 1 - K.repeat_elements(tavuk_crossed, timer_count, axis=1)
     # Write the new flag into slot 0, zero-padding the remaining slots
     tavuk_crossed = K.temporal_padding(tavuk_crossed, (0, timer_count - 1))
     res = tavuk_crossed + x[0] * not_crossed
     return res
Example #8
def extract_seq_patches(x, kernel_size, rate):
    """x.shape = [None, seq_len, seq_dim]
    Slides over the sequence and gathers x for every (dilated) window,
    in preparation for local attention.
    """
    seq_dim = K.int_shape(x)[-1]
    seq_len = K.shape(x)[1]
    # Effective window span once the dilation rate is taken into account
    k_size = kernel_size + (rate - 1) * (kernel_size - 1)
    p_right = (k_size - 1) // 2
    p_left = k_size - 1 - p_right
    x = K.temporal_padding(x, (p_left, p_right))
    xs = [x[:, i:i + seq_len] for i in range(0, k_size, rate)]
    x = K.concatenate(xs, 2)
    return K.reshape(x, (-1, seq_len, kernel_size, seq_dim))
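An illustrative call (sizes invented): each of the 5 positions gathers a window of kernel_size = 3 neighbouring steps, with zeros at the sequence edges.

import numpy as np
from tensorflow.keras import backend as K

x = K.constant(np.arange(10, dtype='float32').reshape(1, 5, 2))
patches = extract_seq_patches(x, kernel_size=3, rate=1)
print(K.eval(patches).shape)  # (1, 5, 3, 2)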
Example #9
 def call(self, inputs):
     if isinstance(inputs, list):
         x, x_mask = inputs
     else:
         x, x_mask = inputs, None
     seq_dim = K.int_shape(x)[-1]
     # Pad the length so the sequence can be reshaped evenly
     seq_len = K.shape(x)[1]
     pad_len = self.rate - seq_len % self.rate
     x = K.temporal_padding(x, (0, pad_len))
     if x_mask is not None:
         x_mask = K.temporal_padding(x_mask, (0, pad_len))
     new_seq_len = K.shape(x)[1]
     # After padding the static shape may become None, so re-declare it
     x = K.reshape(x, (-1, new_seq_len, seq_dim))
     # Linear projections for q, k, v
     qw = self.reuse(self.q_dense, x)
     kw = self.reuse(self.k_dense, x)
     vw = self.reuse(self.v_dense, x)
     # Extract local features
     kernel_size = 1 + 2 * self.neighbors
     # shape=[None, seq_len, kernel_size, out_dim]
     kwp = extract_seq_patches(kw, kernel_size, self.rate)
     # shape=[None, seq_len, kernel_size, out_dim]
     vwp = extract_seq_patches(vw, kernel_size, self.rate)
     if x_mask is not None:
         xp_mask = extract_seq_patches(x_mask, kernel_size, self.rate)
     # Reshape
     qw = K.reshape(qw, (-1, new_seq_len // self.rate, self.rate,
                         self.heads, self.key_size))
     kw = K.reshape(kw, (-1, new_seq_len // self.rate, self.rate,
                         self.heads, self.key_size))
     vw = K.reshape(vw, (-1, new_seq_len // self.rate, self.rate,
                         self.heads, self.size_per_head))
     kwp = K.reshape(kwp, (-1, new_seq_len // self.rate, self.rate,
                           kernel_size, self.heads, self.key_size))
     vwp = K.reshape(vwp, (-1, new_seq_len // self.rate, self.rate,
                           kernel_size, self.heads, self.size_per_head))
     if x_mask is not None:
         x_mask = K.reshape(x_mask,
                            (-1, new_seq_len // self.rate, self.rate, 1, 1))
         xp_mask = K.reshape(
             xp_mask,
             (-1, new_seq_len // self.rate, self.rate, kernel_size, 1, 1))
     # Permute dimensions
     # shape=[None, heads, r, seq_len // r, size]
     qw = K.permute_dimensions(qw, (0, 3, 2, 1, 4))
     kw = K.permute_dimensions(kw, (0, 3, 2, 1, 4))
     vw = K.permute_dimensions(vw, (0, 3, 2, 1, 4))
     qwp = K.expand_dims(qw, 4)
     # shape=[None, heads, r, seq_len // r, kernel_size, out_dim]
     kwp = K.permute_dimensions(kwp, (0, 4, 2, 1, 3, 5))
     vwp = K.permute_dimensions(vwp, (0, 4, 2, 1, 3, 5))
     if x_mask is not None:
         x_mask = K.permute_dimensions(x_mask, (0, 3, 2, 1, 4))
         xp_mask = K.permute_dimensions(xp_mask, (0, 4, 2, 1, 3, 5))
     # Attention1
     a = K.batch_dot(qw, kw, [4, 4]) / self.key_size**0.5
     a = K.permute_dimensions(a, (0, 1, 2, 4, 3))
     a = to_mask(a, x_mask, 'add')
     a = K.permute_dimensions(a, (0, 1, 2, 4, 3))
     if self.mask_right:
         ones = K.ones_like(a[:1, :1, :1])
         mask = (ones - K.tf.matrix_band_part(ones, -1, 0)) * 1e10
         a = a - mask
     # Attention2
     ap = K.batch_dot(qwp, kwp, [5, 5]) / self.key_size**0.5
     ap = K.permute_dimensions(ap, (0, 1, 2, 3, 5, 4))
     if x_mask is not None:
         ap = to_mask(ap, xp_mask, 'add')
     ap = K.permute_dimensions(ap, (0, 1, 2, 3, 5, 4))
     if self.mask_right:
         mask = np.ones((1, kernel_size))
         mask[:, -self.neighbors:] = 0
         mask = (1 - K.constant(mask)) * 1e10
         for _ in range(4):
             mask = K.expand_dims(mask, 0)
         ap = ap - mask
     ap = ap[..., 0, :]
     # Merge the two attention maps
     A = K.concatenate([a, ap], -1)
     A = K.softmax(A)
     a, ap = A[..., :K.shape(a)[-1]], A[..., K.shape(a)[-1]:]
     # Compute output 1
     o1 = K.batch_dot(a, vw, [4, 3])
     # Compute output 2
     ap = K.expand_dims(ap, -2)
     o2 = K.batch_dot(ap, vwp, [5, 4])
     o2 = o2[..., 0, :]
     # Combine the two outputs
     o = o1 + o2
     o = to_mask(o, x_mask, 'mul')
     o = K.permute_dimensions(o, (0, 3, 2, 1, 4))
     o = K.reshape(o, (-1, new_seq_len, self.out_dim))
     o = o[:, :-pad_len]
     return o
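A small sketch of the concat-then-softmax step near the end (scores invented): normalising the global scores a and the local scores ap jointly yields a single distribution over both key sets, which the code then slices back apart.

import numpy as np
from tensorflow.keras import backend as K

a = K.constant([[1.0, 2.0]])   # global attention logits
ap = K.constant([[0.5]])       # local attention logits
A = K.softmax(K.concatenate([a, ap], -1))
a, ap = A[..., :2], A[..., 2:]
print(K.eval(a).sum() + K.eval(ap).sum())  # 1.0: one joint distribution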