Example #1
 def test_batch_dot_shape(self):
     x_batch = KTF.ones(shape=(32, 20))
     y_batch = KTF.ones(shape=(32, 20))
     xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=1)
     assert_allclose(KTF.eval(xy_batch_dot), np.ones((32, 1)) * 20, atol=1e-05)
     xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=0)
     assert_allclose(KTF.eval(xy_batch_dot), np.ones((20, 1)) * 32, atol=1e-05)
     # making sure swapping axes when ndim == 2 works
     x_batch = KTF.ones(shape=(32, 20))
     y_batch = KTF.ones(shape=(20, 32))
     xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=(0, 1))
     assert_allclose(KTF.eval(xy_batch_dot), np.ones((20, 1)) * 32, atol=1e-05)
     xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=(1, 0))
     assert_allclose(KTF.eval(xy_batch_dot), np.ones((32, 1)) * 20, atol=1e-05)
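For reference, the expectations in this test can be reproduced with plain NumPy. The following is a minimal sketch (not part of the test suite) of what batch_dot computes for these 2-D inputs, with einsum standing in for the backend op:

import numpy as np

x = np.ones((32, 20))
y = np.ones((32, 20))

# axes=1 contracts the feature axis pairwise per batch entry -> (32, 1) of 20s
xy = np.einsum('bi,bi->b', x, y)[:, None]
assert xy.shape == (32, 1) and np.allclose(xy, 20.0)

# axes=0 contracts the batch axis instead -> (20, 1) of 32s
yx = np.einsum('bi,bi->i', x, y)[:, None]
assert yx.shape == (20, 1) and np.allclose(yx, 32.0)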
Example #2
 def _call_multiplicative_emission(self, inputs):
     # e_{t, t'} = x_t^T W_a x_{t'} + b_a
     e = K.batch_dot(K.dot(inputs, self.Wa),
                     K.permute_dimensions(inputs, (0, 2, 1)))
     if self.use_attention_bias:
         e += self.ba[0]
     return e
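The emission above is the bilinear score e_{t, t'} = x_t^T W_a x_{t'}: K.dot projects every timestep through Wa, and batch_dot against the transposed inputs contracts the feature axis, leaving a (batch, timesteps, timesteps) score matrix. A minimal NumPy sketch of the same contraction (sizes are illustrative, not from the snippet):

import numpy as np

batch, timesteps, units = 2, 5, 4  # illustrative sizes (assumed)
inputs = np.random.randn(batch, timesteps, units)
Wa = np.random.randn(units, units)

# e[b, t, s] = inputs[b, t] @ Wa @ inputs[b, s]
e = np.einsum('btu,uv,bsv->bts', inputs, Wa, inputs)
assert e.shape == (batch, timesteps, timesteps)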
 def _attention_regularizer(self, attention):
     # penalize || A A^T - I ||^2 so that attention rows stay near-orthogonal;
     # eye is an identity matrix built from the index comparison below
     batch_size = K.cast(K.shape(attention)[0], K.floatx())
     input_len = K.shape(attention)[-1]
     indices = K.expand_dims(K.arange(0, input_len), axis=0)
     diagonal = K.expand_dims(K.arange(0, input_len), axis=-1)
     eye = K.cast(K.equal(indices, diagonal), K.floatx())
     return self.attention_regularizer_weight * K.sum(
         K.square(
             K.batch_dot(attention,
                         K.permute_dimensions(attention, (0, 2, 1))) -
             eye)) / batch_size
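The returned penalty is the squared Frobenius distance between A A^T and the identity, averaged over the batch. A NumPy sketch of the same quantity, with a placeholder weight (the real value comes from the layer's configuration):

import numpy as np

batch, timesteps = 2, 5                  # illustrative sizes (assumed)
a = np.random.rand(batch, timesteps, timesteps)
a /= a.sum(axis=-1, keepdims=True)       # rows sum to 1, like softmax output

weight = 1e-4                            # stand-in for attention_regularizer_weight
eye = np.eye(timesteps)
penalty = weight * np.square(a @ a.transpose(0, 2, 1) - eye).sum() / batch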
 def call(self, inputs, mask=None, **kwargs):
     input_len = K.shape(inputs)[1]

     if self.attention_type == Attention.ATTENTION_TYPE_ADD:
         e = self._call_additive_emission(inputs)
     elif self.attention_type == Attention.ATTENTION_TYPE_MUL:
         e = self._call_multiplicative_emission(inputs)

     if self.attention_activation is not None:
         e = self.attention_activation(e)
     if self.attention_width is not None:
         if self.history_only:
             lower = K.arange(0, input_len) - (self.attention_width - 1)
         else:
             lower = K.arange(0, input_len) - self.attention_width // 2
         lower = K.expand_dims(lower, axis=-1)
         upper = lower + self.attention_width
         indices = K.expand_dims(K.arange(0, input_len), axis=0)
         e -= 10000.0 * (1.0 - K.cast(lower <= indices, K.floatx()) *
                         K.cast(indices < upper, K.floatx()))
     if mask is not None:
         mask = K.expand_dims(K.cast(mask, K.floatx()), axis=-1)
         e -= 10000.0 * ((1.0 - mask) *
                         (1.0 - K.permute_dimensions(mask, (0, 2, 1))))

     # a_{t} = \text{softmax}(e_t)
     e = K.exp(e - K.max(e, axis=-1, keepdims=True))
     a = e / K.sum(e, axis=-1, keepdims=True)

     # l_t = \sum_{t'} a_{t, t'} x_{t'}
     v = K.batch_dot(a, inputs)
     if self.attention_regularizer_weight > 0.0:
         self.add_loss(self._attention_regularizer(a))

     if self.return_attention:
         return [v, a]
     return v
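The closing lines are a hand-rolled, numerically stable softmax followed by batch_dot(a, inputs), which mixes the timestep vectors with the attention weights: l_t = \sum_{t'} a_{t, t'} x_{t'}. A NumPy sketch of those two steps (sizes illustrative):

import numpy as np

batch, timesteps, units = 2, 6, 4        # illustrative sizes (assumed)
e = np.random.randn(batch, timesteps, timesteps)
inputs = np.random.randn(batch, timesteps, units)

# numerically stable softmax over the last axis, as in call() above
e = np.exp(e - e.max(axis=-1, keepdims=True))
a = e / e.sum(axis=-1, keepdims=True)

# v[b, t] = sum_s a[b, t, s] * inputs[b, s] -- what batch_dot(a, inputs) yields
v = np.einsum('bts,bsu->btu', a, inputs)
assert v.shape == (batch, timesteps, units)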
Example #6
 def test_batch_dot_shape(self):
     # unlike Example #1, this variant of the test expects the call below to
     # raise; the two snippets evidently target different Keras releases,
     # between which batch_dot's behaviour changed
     with pytest.raises(ValueError):
         x_batch = KTF.ones(shape=(32, 20))
         y_batch = KTF.ones(shape=(32, 20))
         xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=1)
Example #7
def get_R(X):
    # X is a pair of tensors; contract U_w against vecT batchwise
    U_w, vecT = X[0], X[1]
    print(U_w.shape, vecT.shape)
    ans = K.batch_dot(U_w, vecT)
    return ans
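get_R looks like the merge function of a Lambda layer that receives a pair of tensors. A hedged sketch of that wiring, reusing get_R as defined above (the shapes and the Lambda usage are assumptions; the snippet does not show them):

from keras.layers import Input, Lambda

# hypothetical shapes: U_w is (batch, 10, 5), vecT is (batch, 5, 1); with
# batch_dot's default axes this contracts the size-5 axes -> (batch, 10, 1)
U_w = Input(shape=(10, 5))
vecT = Input(shape=(5, 1))
R = Lambda(get_R)([U_w, vecT])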