import tensorflow as tf


def body_context(u_hats, idx, vs, v, wgt, bias):
    # Route the prediction vectors of time step `idx` to the output capsules.
    u_hat = u_hats[:, idx, :, :, :, :]
    u_hat = tf.matmul(wgt, u_hat) + bias
    max_i = tf.shape(u_hat)[1]
    # Agreement logits between each prediction and the previous output capsule.
    b = tf.matmul(u_hat, tf.tile(v, [1, max_i, 1, 1, 1]), transpose_a=True)
    c = tf.nn.softmax(b, axis=2)
    s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1, keepdims=True)
    v = squash(s, axis=-2)
    vs = vs.write(idx, v)
    return u_hats, tf.add(idx, 1), vs, v, wgt, bias
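# NOTE: `_srf_cond`, the loop condition paired with the two *_body_context
# functions, is referenced by the while_loop below but not shown in this
# excerpt. A minimal sketch under that assumption, mirroring the inline lambda
# used by the nested-routing variant (advance until every time step is routed):
def _srf_cond(u_hats, idx, vs, v, wgt, bias):  # pylint: disable=unused-argument
    # `u_hats` is [batch, seq_len, ...]; stop once `idx` reaches seq_len.
    return tf.less(idx, tf.shape(u_hats)[1])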
def _dr_loop_body(u_hat, b, counter, v, routing_iter, max_i, masking):  # pylint: disable=unused-argument
    b += masking
    c = tf.nn.softmax(b, axis=3)  # caps2_n, since routing to caps2_n
    s = tf.reduce_sum(tf.multiply(c, u_hat), axis=2, keepdims=True)
    v = squash(s, axis=-2)
    b += tf.matmul(u_hat, tf.tile(v, [1, 1, max_i, 1, 1, 1]), transpose_a=True)
    return u_hat, b, tf.add(counter, 1), v, routing_iter, max_i, masking
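# NOTE: `_dr_cond`, the condition paired with `_dr_loop_body`, is likewise not
# shown here. A minimal sketch, assuming it matches the inline lambda of the
# nested variant (loop for a fixed number of routing iterations):
def _dr_cond(u_hat, b, counter, v, routing_iter, max_i, masking):  # pylint: disable=unused-argument
    return tf.less(counter, routing_iter)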
def pad_body_context(u_hats, idx, vs, v, wgt, bias):
    u_hat = u_hats[:, idx, :, :, :, :]
    u_hat = tf.matmul(wgt, u_hat) + bias
    batch = tf.shape(u_hat)[0]
    max_i = tf.shape(u_hat)[1]
    caps2_n = tf.shape(u_hat)[2]
    # Mask the first output capsule (the padding capsule) with a large
    # negative logit so it receives almost no routing weight after softmax.
    masking = tf.concat([
        tf.ones([batch, max_i, 1, 1, 1], dtype=tf.float32) * -1e9,
        tf.zeros([batch, max_i, caps2_n - 1, 1, 1], dtype=tf.float32)
    ], 2)
    b = tf.matmul(u_hat, tf.tile(v, [1, max_i, 1, 1, 1]), transpose_a=True)
    b += masking
    c = tf.nn.softmax(b, axis=2)
    s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1, keepdims=True)
    v = squash(s, axis=-2)
    vs = vs.write(idx, v)
    return u_hats, tf.add(idx, 1), vs, v, wgt, bias
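# NOTE: `squash` is called throughout but defined elsewhere in the project.
# A minimal sketch of the standard capsule squashing nonlinearity
# (Sabour et al., 2017) with an epsilon for numerical stability; the actual
# helper may differ in detail:
def squash(s, axis=-1, epsilon=1e-9):
    squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
    safe_norm = tf.sqrt(squared_norm + epsilon)
    # Shrink each vector to length in [0, 1) while preserving its direction.
    return (squared_norm / (1.0 + squared_norm)) * (s / safe_norm)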
def call(self, inputs, **kwargs):
    inp_len = kwargs["input_lengths"]
    training = kwargs["training"]

    # Capsulation: feature sequences to primary capsules
    conv_out, batch, seq_len = self.conv(inputs, input_lengths=inp_len)
    emb = tf.reshape(conv_out, [batch, seq_len, self.feat_dim * self.nfilt])
    emb = self.proj_pe(emb)
    emb *= tf.math.sqrt(tf.cast(self.ph, tf.float32))
    emb += mh.get_pos_enc(seq_len, self.ph)
    emb = tf.expand_dims(emb, -1)
    emb = tf.math.maximum(self.ecd[0](self.ecs[0](emb)),
                          self.ecd[1](self.ecs[1](emb)))
    emb = self.mask([emb, inp_len, self.stride**2])
    emb = self.mask_layer(emb)
    emb = tf.reshape(emb, [batch, seq_len, self.ph, self.pd])
    emb = squash(emb, -1)
    emb = tf.reshape(emb, [batch, seq_len, self.ph * self.pd])
    emb = self.ln_i(emb)
    emb = tf.reshape(emb, [batch, seq_len, self.ph, self.pd])
    emb = self.inp_dropout(emb, training=training)

    # Contextual Dynamic Routing
    for i in range(self.enc_num):
        inh = tf.shape(self.wgt[i])[0]
        outh, outd = tf.shape(self.wgt[i])[1], tf.shape(self.wgt[i])[2]

        # windowing: concatenate `window` shifted copies along the capsule axis
        pemb = tf.keras.layers.ZeroPadding2D(
            padding=((self.lpad, self.rpad), (0, 0)))(emb)
        emb = tf.concat([
            pemb[:, w:w + tf.shape(emb)[1], :, :] for w in range(self.window)
        ], 2)

        # routing algorithm
        u_hat = tf.einsum('ijkl,bsil->bsijk', self.wgt[i], emb) + \
            tf.tile(self.bias[i], [batch, seq_len, 1, 1, 1])
        if self.is_context:
            vs = tf.TensorArray(dtype=tf.float32, infer_shape=False,
                                size=1, dynamic_size=True)

            @tf.function
            def psdr(_u_hats, idx, _vs, _v):
                # Sequential routing for the last layer: mask the first
                # (padding) output capsule before each softmax.
                bat = tf.shape(_u_hats)[0]
                exp_i = tf.shape(_u_hats)[2]
                out_n = tf.shape(_u_hats)[3]
                pad_mask = tf.concat([
                    tf.ones([bat, exp_i, 1], dtype=tf.float32) * -1e9,
                    tf.zeros([bat, exp_i, out_n - 1], dtype=tf.float32)
                ], 2)
                _u_hat = _u_hats[:, idx, :, :, :]
                b = tf.einsum("bmij,bij->bmi", _u_hat, _v) + pad_mask
                c = tf.nn.softmax(b, axis=2)
                s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
                _v = squash(s, axis=-1)
                for _ in range(1, self.iter):
                    b += tf.einsum("bmij,bij->bmi", _u_hat, _v) + pad_mask
                    c = tf.nn.softmax(b, axis=2)
                    s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
                    _v = squash(s, axis=-1)
                _vs = _vs.write(idx, _v)
                return _u_hats, tf.add(idx, 1), _vs, _v

            @tf.function
            def sdr(_u_hats, idx, _vs, _v):
                # Sequential routing for intermediate layers (no padding mask).
                _u_hat = _u_hats[:, idx, :, :, :]
                b = tf.einsum("bmij,bij->bmi", _u_hat, _v)
                c = tf.nn.softmax(b, axis=2)
                s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
                _v = squash(s, axis=-1)
                for _ in range(1, self.iter):
                    b += tf.einsum("bmij,bij->bmi", _u_hat, _v)
                    c = tf.nn.softmax(b, axis=2)
                    s = tf.reduce_sum(_u_hat * tf.expand_dims(c, -1), axis=1)
                    _v = squash(s, axis=-1)
                _vs = _vs.write(idx, _v)
                return _u_hats, tf.add(idx, 1), _vs, _v

            _, _, vs, _ = \
                tf.while_loop(lambda a, b, c, d: tf.less(b, tf.shape(a)[1]),
                              psdr if i == self.enc_num - 1 else sdr,
                              [u_hat, 0, vs, tf.zeros([batch, outh, outd])])
            emb = tf.reshape(vs.concat(), [seq_len, batch, outh, outd])
            emb = tf.transpose(emb, [1, 0, 2, 3])
        else:
            b = tf.zeros([batch, seq_len, inh, outh, 1, 1], dtype=tf.float32)
            if i == self.enc_num - 1:
                masking = tf.concat([
                    tf.ones([batch, seq_len, inh, 1, 1, 1],
                            dtype=tf.float32) * -1e9,
                    tf.zeros([batch, seq_len, inh, outh - 1, 1, 1],
                             dtype=tf.float32)
                ], 3)
            else:
                masking = tf.zeros([batch, seq_len, inh, outh, 1, 1],
                                   dtype=tf.float32)
            dummy = tf.zeros([batch, seq_len, 1, outh, outd, 1])
            args = [u_hat, b, tf.constant(0), dummy, self.iter, inh, masking]

            @tf.function
            def dr(u_hat, b, counter, v, routing_iter, max_i, masking):
                b += masking
                c = tf.nn.softmax(b, axis=3)  # caps2_n, since routing to caps2_n
                s = tf.reduce_sum(tf.multiply(c, u_hat), axis=2, keepdims=True)
                v = squash(s, axis=-2)
                b += tf.matmul(u_hat, tf.tile(v, [1, 1, max_i, 1, 1, 1]),
                               transpose_a=True)
                return u_hat, b, tf.add(counter, 1), v, routing_iter, max_i, masking

            _, _, _, emb, _, _, _ = \
                tf.while_loop(lambda a, b, c, d, r, f, g: tf.less(c, r), dr, args)
            emb = tf.squeeze(emb, [2, 5])

        # layer normalization and dropout
        emb = tf.reshape(emb, [batch, seq_len, outh * outd])
        emb = self.ln_m[i](emb)
        emb = tf.reshape(emb, [batch, seq_len, outh, outd])
        emb = self.mid_dropout[i](emb, training=training)

    return self.ln_o(
        tf.sqrt(tf.reduce_sum(tf.square(emb), axis=-1) + 1e-9))
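# The -1e9 additive `masking`/`pad_mask` terms above implement padding-capsule
# suppression: after the softmax, the first output capsule receives almost no
# routing weight. A toy, self-contained check (`_demo_padding_mask` is a
# hypothetical name, not part of the model):
def _demo_padding_mask():
    logits = tf.zeros([1, 3])
    mask = tf.constant([[-1e9, 0.0, 0.0]])
    c = tf.nn.softmax(logits + mask, axis=-1)
    # c is approximately [[0.0, 0.5, 0.5]]: the masked capsule gets no weight.
    return c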
def call(self, inputs, **kwargs):
    inp_len = kwargs["input_lengths"]
    training = kwargs["training"]
    lpad = self.lpad
    rpad = self.rpad
    window = self.window
    caps_in_d = tf.shape(self.wgt[0])[4]

    # Capsulation: feature sequences to primary capsules
    conv_out, batch, seq_len = self.conv(inputs, input_lengths=inp_len)
    emb = tf.reshape(conv_out, [batch, seq_len, self.feat_dim * self.nfilt],
                     name="reshape_emb1")
    emb = tf.expand_dims(self.proj_pe(emb), -1)
    emb = tf.math.maximum(self.ecd[0](self.ecs[0](emb)),
                          self.ecd[1](self.ecs[1](emb)))
    emb = self.mask([emb, inp_len, self.stride**2])
    emb = self.mask_layer(emb)
    emb = tf.reshape(emb, [batch, seq_len, self.caps_inp_n, caps_in_d],
                     name="reshape_emb2")
    emb = squash(emb, -1)
    emb = tf.reshape(emb, [batch, seq_len, self.caps_inp_n * caps_in_d],
                     name="reshape_lni1")
    emb = self.ln_i(emb)
    emb = tf.reshape(emb, [batch, seq_len, self.caps_inp_n, caps_in_d],
                     name="reshape_lni2")
    emb = self.inp_dropout(emb, training=training)

    # Contextual Dynamic Routing
    for i in range(self.enc_num):
        caps_in_n, caps_in_d = tf.shape(self.wgt[i])[1], tf.shape(
            self.wgt[i])[4]
        caps_out_n, caps_out_d = tf.shape(self.wgt[i])[2], tf.shape(
            self.wgt[i])[3]

        # windowing: concatenate `window` shifted copies along the capsule axis
        emb_pad = tf.keras.layers.ZeroPadding2D(padding=((lpad, rpad),
                                                         (0, 0)))(emb)
        emb = tf.concat([
            emb_pad[:, w:w + tf.shape(emb)[1], :, :] for w in range(window)
        ], 2)

        # computing prediction vectors
        caps1_ex = tf.expand_dims(tf.expand_dims(emb, -1), 3)
        u_hat = tf.tile(caps1_ex, [1, 1, 1, caps_out_n, 1, 1])
        wgt = tf.tile(self.wgt[i], [batch, 1, 1, 1, 1])
        bias = tf.tile(self.bias[i], [batch, 1, 1, 1, 1])

        # routing algorithm
        if self.is_context:
            vs = tf.TensorArray(dtype=tf.float32, infer_shape=False,
                                size=1, dynamic_size=True)
            v_zero = tf.zeros([batch, 1, caps_out_n, caps_out_d, 1])
            _srf_loop_body = SequenceRouter.pad_body_context \
                if i == self.enc_num - 1 else SequenceRouter.body_context
            _, _, vs, _, _, _ = tf.while_loop(
                SequenceRouter._srf_cond, _srf_loop_body,
                [u_hat, 0, vs, v_zero, wgt, bias])
            emb = tf.reshape(vs.concat(),
                             [seq_len, batch, caps_out_n, caps_out_d],
                             name="reshape_dr")
            emb = tf.transpose(emb, [1, 0, 2, 3])
        else:
            # Dynamic Routing
            b = tf.zeros([batch, seq_len, caps_in_n, caps_out_n, 1, 1],
                         dtype=tf.float32)
            if i == self.enc_num - 1:
                masking = tf.concat([
                    tf.ones([batch, seq_len, caps_in_n, 1, 1, 1],
                            dtype=tf.float32) * -1e9,
                    tf.zeros([batch, seq_len, caps_in_n, caps_out_n - 1, 1, 1],
                             dtype=tf.float32)
                ], 3)
            else:
                masking = tf.zeros(
                    [batch, seq_len, caps_in_n, caps_out_n, 1, 1],
                    dtype=tf.float32)
            dummy = tf.zeros([batch, seq_len, 1, caps_out_n, caps_out_d, 1])
            args = [u_hat, b, tf.constant(0), dummy, 1, caps_in_n, masking]
            _, _, _, emb, _, _, _ = tf.while_loop(
                SequenceRouter._dr_cond, SequenceRouter._dr_loop_body, args)
            emb = tf.squeeze(emb, [2, 5])

        # layer normalization and dropout
        emb = tf.reshape(emb, [batch, seq_len, caps_out_n * caps_out_d],
                         name="reshape_lna%d" % (i + 1))
        emb = self.ln_m[i](emb)
        emb = tf.reshape(emb, [batch, seq_len, caps_out_n, caps_out_d],
                         name="reshape_lnb%d" % (i + 1))
        emb = self.mid_dropout[i](emb, training=training)

    return self.ln_o(length(emb, axis=-1))
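# NOTE: `length`, used in the final projection above, is defined elsewhere. A
# minimal sketch, assuming it computes the same safe vector norm that the
# nested-routing variant inlines (tf.sqrt of the squared sum plus 1e-9):
def length(emb, axis=-1, epsilon=1e-9):
    # A capsule's activation is the Euclidean length of its pose vector.
    return tf.sqrt(tf.reduce_sum(tf.square(emb), axis=axis) + epsilon)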