Ejemplo n.º 1
0
 def call(self, inputs, mask=None, **kwargs):
     """Pass the data tensor through; optionally also emit the float mask.

     When ``self.return_masked`` is set, returns ``[data, mask]`` where the
     mask is this layer's computed mask (index 0) cast to the backend float
     type; otherwise just the data tensor is returned.
     """
     data = inputs[0]
     if not self.return_masked:
         return data
     float_mask = K.cast(self.compute_mask(inputs, mask)[0], K.floatx())
     return [data, float_mask]
Ejemplo n.º 2
0
    def test_mask_loss(self):
        """End-to-end check that Masked() wires the mask into the model.

        The custom loss ignores predictions and just sums the targets, so
        the evaluated value must land on one of the two legal reductions.
        """
        def _loss(y_true, _):
            return K.sum(y_true, axis=-1)

        token_in = keras.layers.Input((5,))
        mask_in = keras.layers.Input((5,))
        embed = keras.layers.Embedding(input_dim=2, output_dim=3, mask_zero=True)(token_in)
        masked = Masked()([embed, mask_in])

        model = keras.models.Model([token_in, mask_in], masked)
        model.compile(optimizer='sgd', loss=_loss)

        tokens = np.array([
            [1, 1, 1, 0, 0],
            [1, 1, 1, 1, 0],
        ])
        masks = np.array([
            [0, 1, 0, 0, 0],
            [1, 0, 0, 0, 0],
        ])
        targets = np.arange(30, dtype=K.floatx()).reshape((2, 5, 3))
        score = model.evaluate([tokens, masks], targets)
        near_six = np.abs(score - 6.0) < 1e-6
        near_thirty = np.abs(score - 30.0) < 1e-6
        self.assertTrue(near_six or near_thirty, score)
Ejemplo n.º 3
0
 def call(self, inputs, mask=None, **kwargs):
     """Return a copy of the first input, plus the float mask if requested.

     ``K.identity`` makes the output a distinct tensor from ``inputs[0]``.
     With ``self.return_masked`` set, the computed mask (index 0) is cast
     to floats and returned alongside the data.
     """
     result = K.identity(inputs[0])
     if not self.return_masked:
         return result
     float_mask = K.cast(self.compute_mask(inputs, mask)[0], K.floatx())
     return [result, float_mask]
 def call(self, inputs, **kwargs):
     """Add a looked-up task embedding to the data tensor.

     ``inputs`` is a pair ``(data, task_ids)``.  Task ids are cast to
     int32 if needed and gathered from ``self.embeddings``; when
     ``mask_zero`` is on, rows for task id 0 are zeroed before adding.
     """
     data, task_ids = inputs
     if K.dtype(task_ids) != 'int32':
         task_ids = K.cast(task_ids, 'int32')
     embedded = K.gather(self.embeddings, task_ids)
     if self.mask_zero:
         keep = K.cast(K.not_equal(task_ids, 0), K.floatx())
         embedded = embedded * K.expand_dims(keep, axis=-1)
     return data + embedded
Ejemplo n.º 5
0
 def call(self, inputs, **kwargs):
     """Add a looked-up task embedding to the data tensor (Theano-safe).

     Same contract as the TensorFlow-only variant, but on the Theano
     backend the embedding is explicitly tiled along the time axis since
     implicit broadcasting is not applied there.
     """
     data, task_ids = inputs
     if K.dtype(task_ids) != 'int32':
         task_ids = K.cast(task_ids, 'int32')
     embedded = K.gather(self.embeddings, task_ids)
     if self.mask_zero:
         keep = K.cast(K.not_equal(task_ids, 0), K.floatx())
         embedded = embedded * K.expand_dims(keep, axis=-1)
     if K.backend() == 'theano':
         # Repeat over timesteps; Theano will not broadcast this addition.
         embedded = K.tile(embedded, (1, K.shape(data)[1], 1))
     return data + embedded
Ejemplo n.º 6
0
    def get_updates(self, loss, params):
        """Build the symbolic parameter-update ops for one training step.

        Adam with a linear learning-rate warmup over `warmup_steps`
        followed by linear decay to zero at `decay_steps`, optional
        AMSGrad, and optional weight decay (restricted by name pattern).

        # Arguments
            loss: Scalar loss tensor to differentiate.
            params: List of trainable weight variables to update.

        # Returns
            List of backend update ops (iteration counter, moment
            accumulators, and the weights themselves).
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        # 1-based step count used by the bias-correction terms below.
        t = K.cast(self.iterations, K.floatx()) + 1

        # Linear warmup, then linear decay; K.minimum clamps t so the
        # decayed rate never goes negative past `decay_steps`.
        lr = K.switch(
            t <= self.warmup_steps,
            self.lr * (t / self.warmup_steps),
            self.lr *
            (1.0 - K.minimum(t, self.decay_steps) / self.decay_steps),
        )

        # Adam bias-corrected step size.
        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                     (1. - K.pow(self.beta_1, t)))

        # First (m) and second (v) moment accumulators, one per parameter.
        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        if self.amsgrad:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            # Dummy slots keep the `self.weights` layout fixed either way.
            vhats = [K.zeros(1) for _ in params]
        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
            if self.amsgrad:
                # AMSGrad: the denominator uses the running max of v_t.
                vhat_t = K.maximum(vhat, v_t)
                p_t = m_t / (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(K.update(vhat, vhat_t))
            else:
                p_t = m_t / (K.sqrt(v_t) + self.epsilon)

            # Weight decay is folded into the update direction (so it is
            # scaled by lr_t below), optionally only for parameters whose
            # names contain one of `weight_decay_pattern`.
            if self.initial_weight_decay > 0.0:
                if self.weight_decay_pattern is None:
                    p_t += self.weight_decay * p
                else:
                    for pattern in self.weight_decay_pattern:
                        if pattern in p.name:
                            p_t += self.weight_decay * p
                            break
            p_t = p - lr_t * p_t

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Respect any per-weight constraint (e.g. max-norm).
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates
Ejemplo n.º 7
0
 def call(self, inputs, mask=None, **kwargs):
     """Scaled dot-product attention.

     ``inputs`` is either a single tensor (self-attention: query = key =
     value) or a list ``[query, key, value]``.  When ``mask`` is a list,
     the key mask (index 1) is used.  Disallowed positions are suppressed
     by subtracting a large constant (1e9) from the logits before softmax.

     Returns the attended values, or ``[values, attention]`` when
     ``self.return_attention`` is set.
     """
     if isinstance(inputs, list):
         query, key, value = inputs
     else:
         query = key = value = inputs
     if isinstance(mask, list):
         mask = mask[1]
     feature_dim = K.shape(query)[-1]
     # Scale logits by sqrt(d) to keep the softmax well-conditioned.
     e = K.batch_dot(query, key, axes=2) / K.sqrt(
         K.cast(feature_dim, dtype=K.floatx()))
     if self.history_only:
         # Causal mask: a query may only attend to keys at or before it.
         query_len, key_len = K.shape(query)[1], K.shape(key)[1]
         ones = tf.ones((query_len, key_len))
         # `tf.linalg.band_part` (available since TF 1.4) replaces the
         # `tf.matrix_band_part` alias, which was removed in TF 2.0.
         e -= (ones - tf.linalg.band_part(ones, -1, 0)) * 1e9
     if mask is not None:
         e -= (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx())) * 1e9
     a = keras.activations.softmax(e)
     v = K.batch_dot(a, value)
     if self.return_attention:
         return [v, a]
     return v
 def call(self, inputs, mask=None):
     """Global max pooling over the timestep axis, ignoring masked steps.

     Masked positions are pushed far below real activations (by 1e6) so
     they can never win the max.
     """
     if mask is None:
         return K.max(inputs, axis=-2)
     penalty = (1.0 - K.cast(mask, K.floatx())) * 1e6
     inputs = inputs - K.expand_dims(penalty, axis=-1)
     return K.max(inputs, axis=-2)
Ejemplo n.º 9
0
 def call(self, inputs, mask=None):
     """Zero out masked timesteps, then run the ordinary Conv1D forward."""
     if mask is not None:
         keep = K.expand_dims(K.cast(mask, K.floatx()), axis=-1)
         inputs = inputs * keep
     return super(MaskedConv1D, self).call(inputs)
Ejemplo n.º 10
0
    def get_updates(self, loss, params):
        """Build symbolic update ops: Adam with warmup, decay to `min_lr`,
        optional AMSGrad, weight decay, and per-layer lr multipliers.

        # Arguments
            loss: Scalar loss tensor to differentiate.
            params: List of trainable weight variables to update.

        # Returns
            List of backend update ops (iteration counter, moment
            accumulators, and the weights themselves).
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        # 1-based step count used by the bias-correction terms below.
        t = K.cast(self.iterations, K.floatx()) + 1

        # Linear warmup up to `warmup_steps`, then linear decay from `lr`
        # down to `min_lr` at `decay_steps` (t clamped via K.minimum).
        lr = K.switch(
            t <= self.warmup_steps,
            self.lr * (t / self.warmup_steps),
            self.min_lr + (self.lr - self.min_lr) *
            (1.0 - K.minimum(t, self.decay_steps) / self.decay_steps),
        )

        # Adam bias-corrected step size.
        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                     (1. - K.pow(self.beta_1, t)))

        # First (m) and second (v) moment accumulators, named so they can
        # be saved/restored with the optimizer weights.
        ms = [
            K.zeros(K.int_shape(p), dtype=K.dtype(p), name='m_{}'.format(i))
            for i, p in enumerate(params)
        ]
        vs = [
            K.zeros(K.int_shape(p), dtype=K.dtype(p), name='v_{}'.format(i))
            for i, p in enumerate(params)
        ]
        if self.amsgrad:
            vhats = [
                K.zeros(K.int_shape(p),
                        dtype=K.dtype(p),
                        name='vh_{}'.format(i)) for i, p in enumerate(params)
            ]
        else:
            # Dummy slots keep the `self.weights` layout fixed either way.
            vhats = [
                K.zeros(1, dtype=K.dtype(p), name='vh_{}'.format(i))
                for i, p in enumerate(params)
            ]
        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
            if self.amsgrad:
                # AMSGrad: the denominator uses the running max of v_t.
                vhat_t = K.maximum(vhat, v_t)
                p_t = m_t / (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(K.update(vhat, vhat_t))
            else:
                p_t = m_t / (K.sqrt(v_t) + self.epsilon)

            # Weight decay is folded into the update direction (scaled by
            # the learning rate below), optionally only for parameters
            # whose names contain one of `weight_decay_pattern`.
            if self.initial_weight_decay > 0.0:
                if self.weight_decay_pattern is None:
                    p_t += self.weight_decay * p
                else:
                    for pattern in self.weight_decay_pattern:
                        if pattern in p.name:
                            p_t += self.weight_decay * p
                            break

            # Per-layer learning-rate multiplier keyed on the first '/'
            # segment of the variable name (e.g. 'dense_1' from
            # 'dense_1/kernel:0').
            mult_lr_t = lr_t
            if self.lr_mult is not None:
                key = p.name.split('/')
                if key[0] in self.lr_mult:
                    #print ("going in params : ", p.name, self.lr_mult[key[0]], K.eval(lr_t))
                    mult_lr_t *= self.lr_mult[key[0]]

            p_t = p - mult_lr_t * p_t

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Respect any per-weight constraint (e.g. max-norm).
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates