Code Example #1
File: layers.py Project: yynnxu/Macadam
    def call(self, u_vecs):
        # K is the Keras backend (e.g. from tensorflow.keras import backend as K)
        # u_vecs shape: [None, input_num_capsule, input_dim_capsule]
        if self.share_weights:
            u_hat_vecs = K.conv1d(u_vecs, self.W)
        else:
            u_hat_vecs = K.local_conv1d(u_vecs, self.W, [1], [1])

        batch_size = K.shape(u_vecs)[0]
        input_num_capsule = K.shape(u_vecs)[1]
        u_hat_vecs = K.reshape(u_hat_vecs,
                               (batch_size, input_num_capsule,
                                self.num_capsule, self.dim_capsule))
        u_hat_vecs = K.permute_dimensions(u_hat_vecs, (0, 2, 1, 3))
        # final u_hat_vecs.shape = [None, num_capsule, input_num_capsule, dim_capsule]

        # routing logits, shape = [None, num_capsule, input_num_capsule]
        b = K.zeros_like(u_hat_vecs[:, :, :, 0])
        outputs = None
        for i in range(self.routings):
            # softmax over the num_capsule axis: move it last, normalize, move it back
            b = K.permute_dimensions(b, (0, 2, 1))  # [None, input_num_capsule, num_capsule]
            c = K.softmax(b)
            c = K.permute_dimensions(c, (0, 2, 1))
            b = K.permute_dimensions(b, (0, 2, 1))
            # weighted sum of the prediction vectors, then the capsule activation (typically squash)
            outputs = self.activation(K.batch_dot(c, u_hat_vecs, [2, 2]))
            if i < self.routings - 1:
                # update the routing logits by the agreement between outputs and predictions
                b = K.batch_dot(outputs, u_hat_vecs, [2, 3])

        return outputs
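
The loop above is the dynamic-routing step of a Capsule layer. Below is a minimal, self-contained NumPy sketch of the same routing logic, written with einsum instead of K.batch_dot and omitting the capsule activation; the sizes (batch, capsule counts, dimensions) and the routing count of 3 are illustrative assumptions, not values taken from the project.

import numpy as np

def softmax(z, axis):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# illustrative sizes only
batch, num_capsule, input_num_capsule, dim_capsule = 2, 10, 6, 16
# stand-in for u_hat_vecs after the reshape/permute in call()
u_hat_vecs = np.random.randn(batch, num_capsule, input_num_capsule, dim_capsule)

b = np.zeros((batch, num_capsule, input_num_capsule))    # routing logits
for _ in range(3):                                       # assumed routing iterations
    c = softmax(b, axis=1)                               # normalize over num_capsule
    outputs = np.einsum('bni,bnid->bnd', c, u_hat_vecs)  # [batch, num_capsule, dim_capsule]
    b = np.einsum('bnd,bnid->bni', outputs, u_hat_vecs)  # agreement update

print(outputs.shape)  # (2, 10, 16)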
Code Example #2
File: layers.py Project: yynnxu/Macadam
    def call(self, x):
        # project the input into query, key and value spaces
        WQ = K.dot(x, self.kernel[0])
        WK = K.dot(x, self.kernel[1])
        WV = K.dot(x, self.kernel[2])
        # scaled dot-product attention scores: Q K^T / sqrt(d_k), with d_k hard-coded as 64
        QK = K.batch_dot(WQ, K.permute_dimensions(WK, [0, 2, 1]))
        QK = QK / (64 ** 0.5)
        QK = K.softmax(QK)
        # weight the value vectors by the attention distribution
        V = K.batch_dot(QK, WV)
        return V
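
Example #2 is single-head scaled dot-product self-attention: the three kernels project the input into queries, keys, and values, the scores are scaled by sqrt(d_k) and softmax-normalized, and the result weights the values. The following is a minimal NumPy sketch of the same computation; all shapes and the projection size of 64 are chosen here purely for illustration.

import numpy as np

def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# illustrative sizes only
batch, seq_len, d_model, d_k = 2, 5, 128, 64
x = np.random.randn(batch, seq_len, d_model)
W_q = np.random.randn(d_model, d_k)
W_k = np.random.randn(d_model, d_k)
W_v = np.random.randn(d_model, d_k)

Q, Kmat, V = x @ W_q, x @ W_k, x @ W_v             # each [batch, seq_len, d_k]
scores = Q @ Kmat.transpose(0, 2, 1) / d_k ** 0.5  # [batch, seq_len, seq_len]
attended = softmax(scores) @ V                     # [batch, seq_len, d_k]

print(attended.shape)  # (2, 5, 64)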