def sdr(_u_hats, idx, _vs, _v):
    # Sequence dynamic-routing body for one time step; `self.iter` and
    # `squash` come from the enclosing layer scope (see Example #7).
    _u_hat = _u_hats[:, idx, :, :, :]
    # Agreement logits between prediction vectors and the output capsules.
    b = tf.einsum("bmij,bij->bmi", _u_hat, _v)
    c = tf.nn.softmax(b, axis=2)
    s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
    _v = squash(s, axis=-1)
    for _ in range(1, self.iter):
        b += tf.einsum("bmij,bij->bmi", _u_hat, _v)
        c = tf.nn.softmax(b, axis=2)
        s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
        _v = squash(s, axis=-1)
    _vs = _vs.write(idx, _v)
    return _u_hats, tf.add(idx, 1), _vs, _v
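
Every routing body in these examples calls a squash helper that the snippets do not define. A minimal sketch of the standard capsule squashing nonlinearity (Sabour et al., 2017), assuming the same (tensor, axis) signature used above:

import tensorflow as tf

def squash(s, axis=-1, epsilon=1e-9):
    # Scale each capsule vector to length ||s||^2 / (1 + ||s||^2) while
    # preserving its direction; epsilon keeps the norm differentiable at 0.
    sq_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
    return s * sq_norm / (1.0 + sq_norm) / tf.sqrt(sq_norm + epsilon)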
Example #2
def body_context(u_hats, idx, vs, v, wgt, bias):
    # Contextual routing body: transform the predictions for time step
    # idx, then run one routing step seeded by the previous step's v.
    u_hat = u_hats[:, idx, :, :, :, :]
    u_hat = tf.matmul(wgt, u_hat) + bias
    max_i = tf.shape(u_hat)[1]
    b = tf.matmul(u_hat, tf.tile(v, [1, max_i, 1, 1, 1]), transpose_a=True)
    c = tf.nn.softmax(b, axis=2)
    s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1, keepdims=True)
    v = squash(s, axis=-2)
    vs = vs.write(idx, v)
    return u_hats, tf.add(idx, 1), vs, v, wgt, bias
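
Example #8 drives a body like this through tf.while_loop with a predicate named SequenceRouter._srf_cond that is not shown there. A runnable sketch with toy shapes; all sizes below are illustrative assumptions, and the predicate is inferred from the loop bound used in Example #7:

import tensorflow as tf

batch, seq_len, max_i, caps_out_n, caps_out_d, in_d = 2, 5, 3, 4, 8, 8

def _srf_cond(u_hats, idx, vs, v, wgt, bias):
    # Keep looping while idx is inside the sequence (time) axis.
    return tf.less(idx, tf.shape(u_hats)[1])

# Toy tensors with the shapes body_context expects.
u_hats = tf.random.normal([batch, seq_len, max_i, caps_out_n, in_d, 1])
wgt = tf.random.normal([batch, max_i, caps_out_n, caps_out_d, in_d])
bias = tf.zeros([batch, max_i, caps_out_n, caps_out_d, 1])
v_zero = tf.zeros([batch, 1, caps_out_n, caps_out_d, 1])
vs = tf.TensorArray(tf.float32, size=1, dynamic_size=True, infer_shape=False)

_, _, vs, _, _, _ = tf.while_loop(_srf_cond, body_context,
                                  [u_hats, 0, vs, v_zero, wgt, bias])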
Example #3
def _dr_loop_body(u_hat, b, counter, v, routing_iter, max_i, masking):
    # pylint: disable=unused-argument
    b += masking
    c = tf.nn.softmax(b, axis=3)  # caps2_n, since routing to caps2_n
    s = tf.reduce_sum(tf.multiply(c, u_hat), axis=2, keepdims=True)
    v = squash(s, axis=-2)
    # Agreement update: raise the logits of input capsules whose
    # predictions align with the new output capsules v.
    b += tf.matmul(u_hat,
                   tf.tile(v, [1, 1, max_i, 1, 1, 1]),
                   transpose_a=True)
    return u_hat, b, tf.add(counter, 1), v, routing_iter, max_i, masking
Example #4
def dr(u_hat, b, counter, v, routing_iter, max_i, masking):
    b += masking
    c = tf.nn.softmax(b, axis=3)  # caps2_n, since routing to caps2_n
    s = tf.reduce_sum(tf.multiply(c, u_hat), axis=2, keepdims=True)
    v = squash(s, axis=-2)
    b += tf.matmul(u_hat,
                   tf.tile(v, [1, 1, max_i, 1, 1, 1]),
                   transpose_a=True)
    return u_hat, b, tf.add(counter, 1), v, routing_iter, max_i, masking
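
Example #7 runs dr inside tf.while_loop until a counter reaches the number of routing iterations, and Example #8 references a SequenceRouter._dr_cond predicate that is not shown. A runnable sketch of that driver with toy shapes (the sizes are illustrative assumptions):

import tensorflow as tf

batch, seq_len, n_in, caps2_n, caps2_d = 2, 5, 3, 4, 8

def _dr_cond(u_hat, b, counter, v, routing_iter, max_i, masking):
    # Stop once counter reaches the requested number of routing passes.
    return tf.less(counter, routing_iter)

u_hat = tf.random.normal([batch, seq_len, n_in, caps2_n, caps2_d, 1])
b = tf.zeros([batch, seq_len, n_in, caps2_n, 1, 1])
masking = tf.zeros_like(b)  # or the -1e9 pad mask used for the last layer
dummy = tf.zeros([batch, seq_len, 1, caps2_n, caps2_d, 1])

args = [u_hat, b, tf.constant(0), dummy, 3, n_in, masking]
_, _, _, v_out, _, _, _ = tf.while_loop(_dr_cond, dr, args)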
Example #5
def psdr(_u_hats, idx, _vs, _v):
    # Padded variant of sdr: an additive -1e9 mask pins the first output
    # capsule's logit, so routing assigns it a near-zero coupling.
    bat = tf.shape(_u_hats)[0]
    exp_i = tf.shape(_u_hats)[2]
    out_n = tf.shape(_u_hats)[3]
    pad_mask = tf.concat([
        tf.ones([bat, exp_i, 1], dtype=tf.float32) * -1e9,
        tf.zeros([bat, exp_i, out_n - 1], dtype=tf.float32)
    ], 2)

    _u_hat = _u_hats[:, idx, :, :, :]
    b = tf.einsum("bmij,bij->bmi", _u_hat, _v) + pad_mask
    c = tf.nn.softmax(b, axis=2)
    s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
    _v = squash(s, axis=-1)
    for _ in range(1, self.iter):
        b += tf.einsum("bmij,bij->bmi", _u_hat, _v) + pad_mask
        c = tf.nn.softmax(b, axis=2)
        s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
        _v = squash(s, axis=-1)
    _vs = _vs.write(idx, _v)
    return _u_hats, tf.add(idx, 1), _vs, _v
Example #6
    def pad_body_context(u_hats, idx, vs, v, wgt, bias):
        u_hat = u_hats[:, idx, :, :, :, :]
        u_hat = tf.matmul(wgt, u_hat) + bias
        batch = tf.shape(u_hat)[0]
        max_i = tf.shape(u_hat)[1]
        caps2_n = tf.shape(u_hat)[2]
        masking = tf.concat([
            tf.ones([batch, max_i, 1, 1, 1], dtype=tf.float32) * -1e9,
            tf.zeros([batch, max_i, caps2_n - 1, 1, 1], dtype=tf.float32)
        ], 2)

        b = tf.matmul(u_hat, tf.tile(v, [1, max_i, 1, 1, 1]), transpose_a=True)
        b += masking
        c = tf.nn.softmax(b, axis=2)
        s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1, keepdims=True)
        v = squash(s, axis=-2)
        vs = vs.write(idx, v)
        return u_hats, tf.add(idx, 1), vs, v, wgt, bias
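
The -1e9 term in pad_body_context is a standard additive softmax mask: it drives the first output capsule's logit toward negative infinity, so its coupling coefficient collapses to (near) zero. A small standalone check of that behavior:

import tensorflow as tf

logits = tf.constant([[-1e9, 0.0, 0.0]])
print(tf.nn.softmax(logits, axis=-1).numpy())  # -> [[0. , 0.5, 0.5]]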
Example #7
    def call(self, inputs, **kwargs):
        inp_len = kwargs["input_lengths"]
        training = kwargs["training"]

        # Capsulation: feature sequences to primary capsule
        conv_out, batch, seq_len = self.conv(inputs, input_lengths=inp_len)
        emb = tf.reshape(conv_out,
                         [batch, seq_len, self.feat_dim * self.nfilt])
        emb = self.proj_pe(emb)
        emb *= tf.math.sqrt(tf.cast(self.ph, tf.float32))
        emb += mh.get_pos_enc(seq_len, self.ph)
        emb = tf.expand_dims(emb, -1)
        emb = tf.math.maximum(self.ecd[0](self.ecs[0](emb)),
                              self.ecd[1](self.ecs[1](emb)))
        emb = self.mask([emb, inp_len, self.stride**2])
        emb = self.mask_layer(emb)
        emb = tf.reshape(emb, [batch, seq_len, self.ph, self.pd])
        emb = squash(emb, -1)
        emb = tf.reshape(emb, [batch, seq_len, self.ph * self.pd])
        emb = self.ln_i(emb)
        emb = tf.reshape(emb, [batch, seq_len, self.ph, self.pd])
        emb = self.inp_dropout(emb, training=training)

        # Contextual Dynamic Routing
        for i in range(self.enc_num):
            inh = tf.shape(self.wgt[i])[0]
            outh, outd = tf.shape(self.wgt[i])[1], tf.shape(self.wgt[i])[2]

            # windowing: concatenate `self.window` shifted copies of the
            # sequence along the capsule axis (w is deliberately distinct
            # from the encoder index i)
            pemb = tf.keras.layers.ZeroPadding2D(
                padding=((self.lpad, self.rpad), (0, 0)))(emb)
            emb = tf.concat([
                pemb[:, w:w + tf.shape(emb)[1], :, :]
                for w in range(self.window)
            ], 2)

            # routing algorithm
            u_hat = tf.einsum('ijkl,bsil->bsijk', self.wgt[i], emb) + \
                    tf.tile(self.bias[i], [batch, seq_len, 1, 1, 1])
            if self.is_context:
                vs = tf.TensorArray(dtype=tf.float32,
                                    infer_shape=False,
                                    size=1,
                                    dynamic_size=True)

                @tf.function
                def psdr(_u_hats, idx, _vs, _v):
                    bat = tf.shape(_u_hats)[0]
                    exp_i = tf.shape(_u_hats)[2]
                    out_n = tf.shape(_u_hats)[3]
                    pad_mask = tf.concat([
                        tf.ones([bat, exp_i, 1], dtype=tf.float32) * -1e9,
                        tf.zeros([bat, exp_i, out_n - 1], dtype=tf.float32)
                    ], 2)

                    _u_hat = _u_hats[:, idx, :, :, :]
                    b = tf.einsum("bmij,bij->bmi", _u_hat, _v) + pad_mask
                    c = tf.nn.softmax(b, axis=2)
                    s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
                    _v = squash(s, axis=-1)
                    for _ in range(1, self.iter):
                        b += tf.einsum("bmij,bij->bmi", _u_hat, _v) + pad_mask
                        c = tf.nn.softmax(b, axis=2)
                        s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1),
                                          axis=1)
                        _v = squash(s, axis=-1)
                    _vs = _vs.write(idx, _v)
                    return _u_hats, tf.add(idx, 1), _vs, _v

                @tf.function
                def sdr(_u_hats, idx, _vs, _v):
                    _u_hat = _u_hats[:, idx, :, :, :]
                    b = tf.einsum("bmij,bij->bmi", _u_hat, _v)
                    c = tf.nn.softmax(b, axis=2)
                    s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
                    _v = squash(s, axis=-1)
                    for _ in range(1, self.iter):
                        b += tf.einsum("bmij,bij->bmi", _u_hat, _v)
                        c = tf.nn.softmax(b, axis=2)
                        s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1),
                                          axis=1)
                        _v = squash(s, axis=-1)
                    _vs = _vs.write(idx, _v)
                    return _u_hats, tf.add(idx, 1), _vs, _v

                # Route each time step; the final encoder layer uses the
                # padded variant psdr.
                _, _, vs, _ = tf.while_loop(
                    lambda a, b, c, d: tf.less(b, tf.shape(a)[1]),
                    psdr if i == self.enc_num - 1 else sdr,
                    [u_hat, 0, vs, tf.zeros([batch, outh, outd])])
                emb = tf.reshape(vs.concat(), [seq_len, batch, outh, outd])
                emb = tf.transpose(emb, [1, 0, 2, 3])
            else:
                b = tf.zeros([batch, seq_len, inh, outh, 1, 1],
                             dtype=tf.float32)
                if i == self.enc_num - 1:
                    masking = tf.concat([
                        tf.ones([batch, seq_len, inh, 1, 1, 1],
                                dtype=tf.float32) * -1e9,
                        tf.zeros([batch, seq_len, inh, outh - 1, 1, 1],
                                 dtype=tf.float32)
                    ], 3)
                else:
                    masking = tf.zeros([batch, seq_len, inh, outh, 1, 1],
                                       dtype=tf.float32)

                dummy = tf.zeros([batch, seq_len, 1, outh, outd, 1])
                args = [
                    u_hat, b,
                    tf.constant(0), dummy, self.iter, inh, masking
                ]

                @tf.function
                def dr(u_hat, b, counter, v, routing_iter, max_i, masking):
                    b += masking
                    # caps2_n, since routing to caps2_n
                    c = tf.nn.softmax(b, axis=3)
                    s = tf.reduce_sum(tf.multiply(c, u_hat),
                                      axis=2,
                                      keepdims=True)
                    v = squash(s, axis=-2)
                    b += tf.matmul(u_hat,
                                   tf.tile(v, [1, 1, max_i, 1, 1, 1]),
                                   transpose_a=True)
                    return (u_hat, b, tf.add(counter, 1), v, routing_iter,
                            max_i, masking)

                _, _, _, emb, _, _, _ = tf.while_loop(
                    lambda a, b, c, d, r, f, g: tf.less(c, r), dr, args)
                emb = tf.squeeze(emb, [2, 5])

            # layer normalization and dropout
            emb = tf.reshape(emb, [batch, seq_len, outh * outd])
            emb = self.ln_m[i](emb)
            emb = tf.reshape(emb, [batch, seq_len, outh, outd])
            emb = self.mid_dropout[i](emb, training=training)

        return self.ln_o(
            tf.sqrt(tf.reduce_sum(tf.square(emb), axis=-1) + 1e-9))
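
Example #7 adds positional information through mh.get_pos_enc, which is not included in these snippets. A sketch of a standard Transformer-style sinusoidal encoding with the same (seq_len, depth) signature; this is an assumption about the helper's behavior, not its actual source:

import tensorflow as tf

def get_pos_enc(seq_len, depth):
    # angle = pos / 10000^(2*floor(i/2)/depth); sine on even channels,
    # cosine on odd ones. Returns shape [seq_len, depth], broadcastable
    # over the batch axis as used in Example #7.
    pos = tf.cast(tf.range(seq_len)[:, tf.newaxis], tf.float32)
    idx = tf.range(depth)[tf.newaxis, :]
    angle = pos / tf.pow(10000.0,
                         2.0 * tf.cast(idx // 2, tf.float32) /
                         tf.cast(depth, tf.float32))
    return tf.where(idx % 2 == 0, tf.sin(angle), tf.cos(angle))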
Example #8
    def call(self, inputs, **kwargs):
        inp_len = kwargs["input_lengths"]
        training = kwargs["training"]
        lpad = self.lpad
        rpad = self.rpad
        window = self.window
        caps_in_d = tf.shape(self.wgt[0])[4]

        # Capsulation: feature sequences to primary capsule
        conv_out, batch, seq_len = self.conv(inputs, input_lengths=inp_len)

        emb = tf.reshape(conv_out,
                         [batch, seq_len, self.feat_dim * self.nfilt],
                         name="reshape_emb1")
        emb = tf.expand_dims(self.proj_pe(emb), -1)
        emb = tf.math.maximum(self.ecd[0](self.ecs[0](emb)),
                              self.ecd[1](self.ecs[1](emb)))
        emb = self.mask([emb, inp_len, self.stride**2])
        emb = self.mask_layer(emb)
        emb = tf.reshape(emb, [batch, seq_len, self.caps_inp_n, caps_in_d],
                         name="reshape_emb2")
        emb = squash(emb, -1)
        emb = tf.reshape(emb, [batch, seq_len, self.caps_inp_n * caps_in_d],
                         name="reshape_lni1")
        emb = self.ln_i(emb)
        emb = tf.reshape(emb, [batch, seq_len, self.caps_inp_n, caps_in_d],
                         name="reshape_lni2")
        emb = self.inp_dropout(emb, training=training)

        # Contextual Dynamic Routing
        for i in range(self.enc_num):
            caps_in_n = tf.shape(self.wgt[i])[1]
            caps_in_d = tf.shape(self.wgt[i])[4]
            caps_out_n = tf.shape(self.wgt[i])[2]
            caps_out_d = tf.shape(self.wgt[i])[3]

            # windowing: concatenate `window` shifted copies of the sequence
            # along the capsule axis (w is deliberately distinct from the
            # encoder index i)
            emb_pad = tf.keras.layers.ZeroPadding2D(padding=((lpad, rpad),
                                                             (0, 0)))(emb)
            emb = tf.concat([
                emb_pad[:, w:w + tf.shape(emb)[1], :, :] for w in range(window)
            ], 2)

            # computing prediction vectors
            caps1_ex = tf.expand_dims(tf.expand_dims(emb, -1), 3)
            u_hat = tf.tile(caps1_ex, [1, 1, 1, caps_out_n, 1, 1])
            wgt = tf.tile(self.wgt[i], [batch, 1, 1, 1, 1])
            bias = tf.tile(self.bias[i], [batch, 1, 1, 1, 1])
            # routing algorithm
            if self.is_context:
                vs = tf.TensorArray(dtype=tf.float32,
                                    infer_shape=False,
                                    size=1,
                                    dynamic_size=True)
                v_zero = tf.zeros([batch, 1, caps_out_n, caps_out_d, 1])
                _srf_loop_body = SequenceRouter.pad_body_context \
                  if i == self.enc_num - 1 else SequenceRouter.body_context
                _, _, vs, _, _, _ = tf.while_loop(
                    SequenceRouter._srf_cond, _srf_loop_body,
                    [u_hat, 0, vs, v_zero, wgt, bias])
                emb = tf.reshape(vs.concat(),
                                 [seq_len, batch, caps_out_n, caps_out_d],
                                 name="reshape_dr")
                emb = tf.transpose(emb, [1, 0, 2, 3])
            else:  # Dynamic Routing
                b = tf.zeros([batch, seq_len, caps_in_n, caps_out_n, 1, 1],
                             dtype=tf.float32)
                if i == self.enc_num - 1:
                    masking = tf.concat([
                        tf.ones([batch, seq_len, caps_in_n, 1, 1, 1],
                                dtype=tf.float32) * -1e9,
                        tf.zeros(
                            [batch, seq_len, caps_in_n, caps_out_n - 1, 1, 1],
                            dtype=tf.float32)
                    ], 3)
                else:
                    masking = tf.zeros(
                        [batch, seq_len, caps_in_n, caps_out_n, 1, 1],
                        dtype=tf.float32)
                dummy = tf.zeros(
                    [batch, seq_len, 1, caps_out_n, caps_out_d, 1])
                # Note: routing_iter is hard-coded to 1 here, i.e. a single
                # routing pass per layer.
                args = [u_hat, b, tf.constant(0), dummy, 1, caps_in_n, masking]
                _, _, _, emb, _, _, _ = tf.while_loop(
                    SequenceRouter._dr_cond, SequenceRouter._dr_loop_body,
                    args)
                emb = tf.squeeze(emb, [2, 5])

            # layer normalization and dropout
            emb = tf.reshape(emb, [batch, seq_len, caps_out_n * caps_out_d],
                             name="reshape_lna%d" % (i + 1))
            emb = self.ln_m[i](emb)
            emb = tf.reshape(emb, [batch, seq_len, caps_out_n, caps_out_d],
                             name="reshape_lnb%d" % (i + 1))
            emb = self.mid_dropout[i](emb, training=training)

        return self.ln_o(length(emb, axis=-1))
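
The final projection relies on a length helper that is likewise not shown; a minimal sketch consistent with the inline capsule-norm computation at the end of Example #7, where the small epsilon keeps the sqrt gradient finite at zero:

import tensorflow as tf

def length(capsules, axis=-1, epsilon=1e-9):
    # Euclidean norm of each capsule vector, matching
    # tf.sqrt(tf.reduce_sum(tf.square(emb), axis=-1) + 1e-9) in Example #7.
    return tf.sqrt(tf.reduce_sum(tf.square(capsules), axis=axis) + epsilon)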