Ejemplo n.º 1
0
def dot_product(x, kernel):
    """Multiply ``x`` by ``kernel`` with ``K.dot``, special-casing TensorFlow.

    On every backend except TensorFlow the call is a plain
    ``K.dot(x, kernel)``.  On TensorFlow the kernel is first passed through
    ``K.expand_dims`` (default axis), the dot product is taken against that
    expanded kernel, and the last axis of the result is squeezed away again.

    Args:
        x: Left operand handed to ``K.dot``.
        kernel: Right operand; presumably a weight vector -- TODO confirm
            against the callers.

    Returns:
        The backend tensor produced by the dot product.
    """
    # Guard clause: non-TensorFlow backends take the direct route.
    if K.backend() != 'tensorflow':
        return K.dot(x, kernel)

    expanded_kernel = K.expand_dims(kernel)
    product = K.dot(x, expanded_kernel)
    # Drop the trailing axis that the expanded kernel left behind.
    return K.squeeze(product, axis=-1)
Ejemplo n.º 2
0
 def call(self, x, mask=None):
     """Reduce axis 1 of ``x`` to an attention-weighted sum.

     A score is computed as ``tanh(K.dot(x, W) + b)``, exponentiated, and
     normalised so the weights sum to one along axis 1.  ``x`` is then
     scaled by those weights and summed over axis 1.

     Only backend-agnostic ``K.*`` operations are used; no backend-specific
     tensor methods (such as Theano's ``dimshuffle``) are called, so the
     layer is not tied to a single backend by this method.

     Args:
         x: Input tensor.  Assumed to be rank 3 with the attended axis at
             position 1, and ``self.W`` to be a vector so the scores are
             rank 2 -- TODO confirm against ``build``.
         mask: Accepted for interface compatibility; not used here.

     Returns:
         The weighted sum of ``x`` over axis 1.
     """
     # NOTE(review): whether K.dot accepts this x / W rank combination on
     # every backend is not verified here.
     eij = K.tanh(K.dot(x, self.W) + self.b)
     ai = K.exp(eij)
     # keepdims=True keeps the reduced axis as size 1 so the division
     # broadcasts across axis 1.
     weights = ai / K.sum(ai, axis=1, keepdims=True)
     # Append a trailing axis so the weights broadcast over the last axis
     # of x.
     weighted_input = x * K.expand_dims(weights, -1)
     return K.sum(weighted_input, axis=1)
Ejemplo n.º 3
0
 def call(self, x, mask=None):
     """Project ``x``, optionally add the bias, mask, then activate.

     Steps, in order:
       1. ``K.dot(x, self.W)``
       2. add ``self.b`` when ``self.bias`` is truthy
       3. ``K.switch`` on ``self.sparse_mask``: keep the projected value
          where the condition holds, use ``-1e20`` elsewhere
       4. apply ``self.activation``

     Args:
         x: Input tensor.
         mask: Accepted for interface compatibility; not used here.

     Returns:
         The activated, masked projection.
     """
     projected = K.dot(x, self.W)
     if self.bias:
         projected += self.b
     # -1e20 is a large negative fill value; presumably chosen so masked
     # entries vanish under a softmax-like activation -- TODO confirm.
     masked = K.switch(self.sparse_mask, projected, -1e20)
     return self.activation(masked)
Ejemplo n.º 4
0
    def step(self, inputs, states):
        """Run one recurrent step of an LSTM-style cell with soft attention.

        The previous hidden state is scored against every position of the
        sequence stored in ``states[4]``; the resulting weights produce a
        context vector ``z_hat`` that feeds each gate alongside the regular
        input and recurrent terms.

        Args:
            inputs: Input for the current step.  In implementation mode 0 it
                is sliced directly into the four gate pre-activations
                (presumably already projected elsewhere -- TODO confirm).
            states: Indexable with five entries, read as
                ``[h_tm1, c_tm1, dp_mask, rec_dp_mask, x_input]``.

        Returns:
            A pair ``(output, [h, c])``.  ``output`` is the attention-weighted
            sequence (before summation over axis 1) when
            ``self.return_attention`` is truthy, otherwise the new hidden
            state ``h``.

        Raises:
            ValueError: If ``self.implementation`` is not 0, 1 or 2.
        """
        h_tm1 = states[0]
        c_tm1 = states[1]
        # Presumably the input / recurrent dropout masks (one per gate) --
        # confirm against the code that builds `states`.
        dp_mask = states[2]
        rec_dp_mask = states[3]
        # Sequence that attention is computed over.
        x_input = states[4]

        # alignment model
        # Repeat the previous hidden state `timestep_dim` times so it can be
        # combined with the per-position projection of `x_input`.
        h_att = K.repeat(h_tm1, self.timestep_dim)
        att = _time_distributed_dense(x_input,
                                      self.attention_weights,
                                      self.attention_bias,
                                      output_dim=K.int_shape(
                                          self.attention_weights)[1])
        attention_ = self.attention_activation(
            K.dot(h_att, self.attention_recurrent_weights) + att)
        # NOTE(review): despite its name, `attention_recurrent_bias` is used
        # as the right operand of a dot product here, not added as a bias.
        # Axis 2 of the result is squeezed away.
        attention_ = K.squeeze(
            K.dot(attention_, self.attention_recurrent_bias), 2)

        # Unnormalised attention weights.
        alpha = K.exp(attention_)

        # NOTE(review): `dp_mask` is guarded against None here but indexed
        # unconditionally in the implementation 1 and 2 branches below.
        if dp_mask is not None:
            alpha *= dp_mask[0]

        # Normalise so the weights sum to one along axis 1.
        alpha /= K.sum(alpha, axis=1, keepdims=True)
        # Repeat the weights `input_dim` times and swap the last two axes so
        # they line up element-wise with `x_input`.
        alpha_r = K.repeat(alpha, self.input_dim)
        alpha_r = K.permute_dimensions(alpha_r, (0, 2, 1))

        # make context vector (soft attention after Bahdanau et al.)
        z_hat = x_input * alpha_r
        # Keep the weighted sequence itself; it is the output when
        # `return_attention` is set.
        context_sequence = z_hat
        z_hat = K.sum(z_hat, axis=1)

        if self.implementation == 2:
            # Fused path: a single pre-activation `z` holds all four gates.
            z = K.dot(inputs * dp_mask[0], self.kernel)
            z += K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel)
            z += K.dot(z_hat, self.attention_kernel)

            if self.use_bias:
                z = K.bias_add(z, self.bias)

            # Split `z` into four consecutive slices of width `units`.
            z0 = z[:, :self.units]
            z1 = z[:, self.units:2 * self.units]
            z2 = z[:, 2 * self.units:3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)
        else:
            if self.implementation == 0:
                # `inputs` is sliced as-is; no kernel, bias or mask is
                # applied to it in this method.
                x_i = inputs[:, :self.units]
                x_f = inputs[:, self.units:2 * self.units]
                x_c = inputs[:, 2 * self.units:3 * self.units]
                x_o = inputs[:, 3 * self.units:]
            elif self.implementation == 1:
                # Per-gate kernels, each with its own mask entry.  Unlike
                # mode 2, the biases are added without checking `use_bias`.
                x_i = K.dot(inputs * dp_mask[0], self.kernel_i) + self.bias_i
                x_f = K.dot(inputs * dp_mask[1], self.kernel_f) + self.bias_f
                x_c = K.dot(inputs * dp_mask[2], self.kernel_c) + self.bias_c
                x_o = K.dot(inputs * dp_mask[3], self.kernel_o) + self.bias_o
            else:
                raise ValueError('Unknown `implementation` mode.')

            # Each gate sums three terms: the input term, the recurrent term
            # and the attention-context term.
            i = self.recurrent_activation(
                x_i + K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel_i) +
                K.dot(z_hat, self.attention_i))
            f = self.recurrent_activation(
                x_f + K.dot(h_tm1 * rec_dp_mask[1], self.recurrent_kernel_f) +
                K.dot(z_hat, self.attention_f))
            c = f * c_tm1 + i * self.activation(
                x_c + K.dot(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c) +
                K.dot(z_hat, self.attention_c))
            o = self.recurrent_activation(
                x_o + K.dot(h_tm1 * rec_dp_mask[3], self.recurrent_kernel_o) +
                K.dot(z_hat, self.attention_o))
        # New hidden state from the output gate and the new cell state.
        h = o * self.activation(c)
        # Flag the output tensor whenever any dropout rate is non-zero.
        if 0 < self.dropout + self.recurrent_dropout:
            h._uses_learning_phase = True

        if self.return_attention:
            return context_sequence, [h, c]
        else:
            return h, [h, c]