Exemple #1
0
def KD_SVD(student_feature_maps, teacher_feature_maps):
    '''
    Build the SVD-based distillation loss between paired feature maps.

    Seung Hyun Lee, Dae Ha Kim, and Byung Cheol Song.
    Self-supervised knowledge distillation using singular value decomposition. In
    European Conference on Computer Vision, pages 339–354. Springer, 2018.

    Every student/teacher pair is decomposed with SVP.SVD and passed through
    SVP.Align_rsv; the vectors of both networks are then scaled by the
    teacher's singular values. For each pair after the first, an element-wise
    RBF response exp(-(a - b)^2 / 8) is computed between the current and the
    previous vectors, and the student's response is regressed onto the
    teacher's (the teacher side is wrapped in tf.stop_gradient).

    Args:
        student_feature_maps: iterable of student feature-map tensors.
        teacher_feature_maps: iterable of teacher feature-map tensors, paired
            positionally with the student ones (zip drops unpaired extras).

    Returns:
        Scalar tensor: the sum, over consecutive pairs, of the squared
        differences between student and teacher RBF responses, with
        non-finite entries replaced by zero.

    Raises:
        ValueError: if fewer than two pairs are given, so no loss term exists.
    '''
    with tf.variable_scope('Distillation'):
        GNN_losses = []
        # Forwarded to SVP.SVD and SVP.Align_rsv; presumably the number of
        # singular vectors kept per feature map -- TODO confirm against SVP.
        K = 4
        V_Tb = V_Sb = None
        for i, (sfm, tfm) in enumerate(zip(student_feature_maps, teacher_feature_maps)):
            with tf.variable_scope('Compress_feature_map%d'%i):
                Sigma_T, _, V_T = SVP.SVD(tfm, K, name = 'TSVD%d'%i)
                _, U_S, V_S = SVP.SVD(sfm, K, name = 'SSVD%d'%i)
                V_S, _, V_T = SVP.Align_rsv(V_S, V_T, U_S, Sigma_T, K)
                Sigma_T = tf.expand_dims(Sigma_T,1)
                # Both networks are weighted by the *teacher's* singular
                # values; the student's own singular values are not used.
                V_T = V_T * Sigma_T
                V_S = V_S * Sigma_T

            # The RBF needs a previous layer, so the first pair only seeds it.
            if i > 0:
                with tf.variable_scope('RBF%d'%i):
                    # Expanding on axis 2 vs. axis 1 makes the subtraction
                    # broadcast over all current/previous combinations along
                    # those axes.
                    S_rbf = tf.exp(-tf.square(tf.expand_dims(V_S,2)-tf.expand_dims(V_Sb,1))/8)
                    T_rbf = tf.exp(-tf.square(tf.expand_dims(V_T,2)-tf.expand_dims(V_Tb,1))/8)

                    # The teacher response is a fixed target: no gradient.
                    l2loss = (S_rbf-tf.stop_gradient(T_rbf))**2
                    # Zero out NaN/Inf entries so they cannot poison the sum.
                    l2loss = tf.where(tf.is_finite(l2loss), l2loss, tf.zeros_like(l2loss))
                    GNN_losses.append(tf.reduce_sum(l2loss))
            V_Tb = V_T
            V_Sb = V_S

        # tf.add_n rejects an empty list with a generic message; fail with
        # the same exception type but say what is actually wrong.
        if not GNN_losses:
            raise ValueError('KD_SVD needs at least two student/teacher '
                             'feature-map pairs to build a transfer loss')

        transfer_loss = tf.add_n(GNN_losses)

        return transfer_loss
Exemple #2
0
def MHGD(student_feature_maps, teacher_feature_maps):
    '''
    Build the multi-head graph distillation (MHGD) transfer loss.

    Seunghyun Lee, Byung Cheol Song.
    Graph-based Knowledge Distillation by Multi-head Attention Network.
    British Machine Vision Conference (BMVC) 2019

    Each student/teacher pair is compressed with SVP.SVD_eid, the student
    vectors are aligned to the teacher's with SVP.Align_rsv, and both are
    flattened to 2-D. For every pair after the first, attention graphs
    between the current and previous vectors are built for teacher and
    student with the same 'Attention' variables, centred with the teacher's
    mean, squashed by tanh and compared with kld_loss.

    Side effects:
        Adds one term per consecutive pair to the 'MHA_loss' collection
        (1 - sum(V_T_ * V_T) averaged, where V_T_ is the Estimator output).
        fully_connected / batch_norm variables created under this function
        are also placed in the 'MHA' collection via arg_scope.

    Args:
        student_feature_maps: iterable of student feature-map tensors.
        teacher_feature_maps: iterable of teacher feature-map tensors, paired
            positionally with the student ones (zip drops unpaired extras).

    Returns:
        Tensor: tf.add_n over the per-pair kld_loss(G_S, G_T) terms.

    Raises:
        ValueError: if fewer than two pairs are given, so no loss term exists.
    '''
    with tf.variable_scope('MHGD'):
        with tf.contrib.framework.arg_scope([tf.contrib.layers.fully_connected], trainable = True,
                                            weights_initializer=tf.initializers.random_normal(),
                                            weights_regularizer=None, variables_collections = [tf.GraphKeys.GLOBAL_VARIABLES,'MHA']):
            with tf.contrib.framework.arg_scope([tf.contrib.layers.batch_norm], activation_fn=None, trainable = True,
                                                param_regularizers = None, variables_collections=[tf.GraphKeys.GLOBAL_VARIABLES,'MHA']):
                GNN_losses = []
                num_head = 8
                V_Tb = V_Sb = None
                for i, (sfm, tfm) in enumerate(zip(student_feature_maps, teacher_feature_maps)):
                    with tf.variable_scope('Compress_feature_map%d'%i):
                        # Only the third output of SVD_eid is used here.
                        _, _, V_T = SVP.SVD_eid(tfm, 1, name = 'TSVD%d'%i)
                        _, _, V_S = SVP.SVD_eid(sfm, 1, name = 'SSVD%d'%i)
                        # The second output of Align_rsv is not needed.
                        V_S, _ = SVP.Align_rsv(V_T, V_S)
                        D = V_T.get_shape().as_list()[1]

                        V_T = tf.reshape(V_T,[-1,D])
                        V_S = tf.reshape(V_S,[-1,D])

                    with tf.variable_scope('MHA%d'%i):
                        # The first pair has no predecessor; it only seeds
                        # V_Tb / V_Sb for the next iteration.
                        if i > 0:
                            D_prev = V_Sb.get_shape().as_list()[1]
                            D2 = (D+D_prev)//2

                            # First call creates the 'Attention' variables;
                            # its output feeds the Estimator objective.
                            G_T = Attention_head(V_T, V_Tb, D2, num_head, 'Attention', is_training = True)
                            V_T_ = Estimator(V_Tb, G_T, D, num_head, 'Estimator')
                            tf.add_to_collection('MHA_loss', tf.reduce_mean(1-tf.reduce_sum(V_T_*V_T, -1)) )

                            # Rebuild both graphs with the shared variables
                            # (reuse = True) for the distillation loss.
                            G_T = Attention_head(V_T, V_Tb, D2, num_head, 'Attention', reuse = True)
                            G_S = Attention_head(V_S, V_Sb, D2, num_head, 'Attention', reuse = True)

                            # Teacher and student are both centred with the
                            # teacher's mean over the last axis.
                            teacher_mean = tf.reduce_mean(G_T, -1, keepdims=True)
                            G_T = tf.tanh(G_T-teacher_mean)
                            G_S = tf.tanh(G_S-teacher_mean)

                            GNN_losses.append(kld_loss(G_S, G_T))

                    V_Tb = V_T
                    V_Sb = V_S

                # tf.add_n rejects an empty list with a generic message; fail
                # with the same exception type but say what is wrong.
                if not GNN_losses:
                    raise ValueError('MHGD needs at least two student/teacher '
                                     'feature-map pairs to build a transfer loss')

                transfer_loss = tf.add_n(GNN_losses)

                return transfer_loss
Exemple #3
0
def MHGD(student_feature_maps, teacher_feature_maps):
    '''
    Build the multi-head graph distillation (MHGD) transfer loss.

    Seunghyun Lee, Byung Cheol Song.
    Graph-based Knowledge Distillation by Multi-head Attention Network.
    British Machine Vision Conference (BMVC) 2019

    Each student/teacher pair is compressed with SVP.SVD_eid (second argument
    1 for the teacher, 4 for the student), aligned with SVP.Align_rsv and
    flattened to 2-D. For every pair after the first, attention graphs
    between the current and previous vectors are built for teacher and
    student with the same 'Attention' variables, squashed by tanh and
    compared with kld_loss.

    Side effects:
        Adds one term per consecutive pair to the 'MHA_loss' collection
        (1 - sum(V_T_ * V_T) averaged, where V_T_ is the Estimator output).

    Args:
        student_feature_maps: iterable of student feature-map tensors.
        teacher_feature_maps: iterable of teacher feature-map tensors, paired
            positionally with the student ones (zip drops unpaired extras).

    Returns:
        Tensor: tf.add_n over the per-pair kld_loss(G_S, G_T) terms.

    Raises:
        ValueError: if fewer than two pairs are given, so no loss term exists.
    '''
    with tf.variable_scope('MHGD'):
        GNN_losses = []
        num_head = 8
        V_Tb = V_Sb = None
        for i, (sfm, tfm) in enumerate(zip(student_feature_maps, teacher_feature_maps)):
            with tf.variable_scope('Compress_feature_map%d'%i):
                # Only the third output of SVD_eid is used here.
                _, _, V_T = SVP.SVD_eid(tfm, 1, name = 'TSVD%d'%i)
                # NOTE(review): the student is decomposed with 4 where the
                # teacher uses 1; presumably Align_rsv reduces it to match
                # the teacher -- confirm against SVP.
                _, _, V_S = SVP.SVD_eid(sfm, 4, name = 'SSVD%d'%i)
                V_S, V_T = SVP.Align_rsv(V_S, V_T)
                D = V_T.get_shape().as_list()[1]

                V_T = tf.reshape(V_T,[-1,D])
                V_S = tf.reshape(V_S,[-1,D])

            with tf.variable_scope('MHA%d'%i):
                # The first pair has no predecessor; it only seeds
                # V_Tb / V_Sb for the next iteration.
                if i > 0:
                    D_prev = V_Sb.get_shape().as_list()[1]
                    D2 = (D+D_prev)//2

                    # First call creates the 'Attention' variables; its
                    # output feeds the Estimator objective.
                    G_T = Attention_head(V_T, V_Tb, D2, num_head, 'Attention', is_training = True)
                    V_T_ = Estimator(V_Tb, G_T, D, num_head, 'Estimator', is_training = True)
                    tf.add_to_collection('MHA_loss', tf.reduce_mean(1-tf.reduce_sum(V_T_*V_T, -1)) )

                    # Rebuild both graphs with the shared variables
                    # (reuse = True) for the distillation loss.
                    G_T = Attention_head(V_T, V_Tb, D2, num_head, 'Attention', reuse = True)
                    G_S = Attention_head(V_S, V_Sb, D2, num_head, 'Attention', reuse = True)
                    G_T = tf.tanh(G_T)
                    G_S = tf.tanh(G_S)

                    GNN_losses.append(kld_loss(G_S, G_T))

            V_Tb, V_Sb = V_T, V_S

        # tf.add_n rejects an empty list with a generic message; fail with
        # the same exception type but say what is actually wrong.
        if not GNN_losses:
            raise ValueError('MHGD needs at least two student/teacher '
                             'feature-map pairs to build a transfer loss')

        transfer_loss = tf.add_n(GNN_losses)

        return transfer_loss
Exemple #4
0
def EKI(student_feature_maps, teacher_feature_maps):
    '''
    Build the EKI distillation losses and register them in graph collections.

    For each student/teacher pair the teacher feature map is decomposed with
    SVP.SVD_eid, sign-normalised, passed through Plane_mapping and
    PCA_Graph; the student feature map is projected onto the teacher's U_T,
    l2-normalised, passed through Plane_mapping and mapped with the
    teacher's PCA mean and projection. For every pair after the first, MPNN
    and Attention_knowledge relate the current embeddings to the previous
    ones and yield three loss terms.

    Side effects:
        * 'MHA_loss' collection: kld_loss(G_B, G_T) per consecutive pair.
        * 'dist' collection: three tensors, in this order -- the summed
          mean-L1 message differences, the summed kld_loss(G_T, G_S) terms,
          and the summed Attention_knowledge terms.
        * fully_connected / batch_norm variables created under this function
          are also placed in the 'MHA' collection via arg_scope.

    Args:
        student_feature_maps: iterable of student feature-map tensors; each
            is indexed as a 4-D shape (dims 1 and 2 are flattened together).
        teacher_feature_maps: iterable of teacher feature-map tensors, paired
            positionally with the student ones (zip drops unpaired extras).

    Returns:
        None. Results are published only through the collections above.

    Raises:
        ValueError: if fewer than two pairs are given, so no loss term exists.
    '''
    with tf.variable_scope('EKI'):
        with tf.contrib.framework.arg_scope(
                [tcl.fully_connected],
                weights_regularizer=None,
                variables_collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'MHA']):
            with tf.contrib.framework.arg_scope(
                    [tcl.batch_norm],
                    activation_fn=None,
                    param_regularizers=None,
                    variables_collections=[
                        tf.GraphKeys.GLOBAL_VARIABLES, 'MHA'
                    ]):
                GNN_losses0 = []
                GNN_losses1 = []
                GNN_losses2 = []
                # Embeddings of the previous pair (teacher / student).
                T_F = S_F = None

                for i, (sfm, tfm) in enumerate(
                        zip(student_feature_maps, teacher_feature_maps)):
                    sz = sfm.get_shape().as_list()
                    # Last (channel) dimension of the teacher feature map.
                    D_B = tfm.get_shape().as_list()[-1]

                    with tf.variable_scope('Stacked_PCA%d' % i):
                        with tf.variable_scope('Teacher_PCA'):
                            _, U_T, V_T = SVP.SVD_eid(tfm, 1, name='TSVD')
                            # Flip signs so that, along axis 1, the entry of
                            # largest magnitude in V_T is positive; U_T is
                            # flipped with it to keep the pair consistent.
                            sign = tf.sign(
                                tf.reduce_max(V_T, 1, keepdims=True) +
                                tf.reduce_min(V_T, 1, keepdims=True))
                            V_T *= sign
                            U_T *= sign

                            T_B = tf.reshape(V_T, [-1, D_B])
                            T_B = Plane_mapping(T_B)
                            G_B, T_B, mean_V, P_B = PCA_Graph(T_B)

                        with tf.variable_scope('Student_PCA'):
                            # Flatten dims 1 and 2 so the student map can be
                            # projected onto the teacher's U_T; the student
                            # gets no decomposition of its own.
                            sfm = tf.reshape(sfm, [-1, sz[1] * sz[2], sz[3]])
                            V_S = tf.nn.l2_normalize(
                                tf.matmul(sfm, U_T, transpose_a=True), 1)

                            S_B = tf.reshape(V_S, [-1, D_B])
                            S_B = Plane_mapping(S_B)
                            # Reuse the teacher's PCA mean and projection.
                            S_B = tf.matmul(S_B - mean_V, P_B)

                    with tf.variable_scope('EKI_module%d' % i):
                        # The first pair has no predecessor; it only seeds
                        # T_F / S_F for the next iteration.
                        if i > 0:
                            with tf.variable_scope('MPNN'):
                                D = D_B // 2
                                num_iter = 2
                                G_T, _ = MPNN(T_B, T_F, D, num_iter, 'MPNN')
                                tf.add_to_collection('MHA_loss',
                                                     kld_loss(G_B, G_T))

                                # Update Graph Knowledge
                                # NOTE(review): the trailing (False, True)
                                # are presumably is_training / reuse --
                                # confirm against MPNN's signature.
                                G_T, M_T = MPNN(T_B, T_F, D, num_iter, 'MPNN',
                                                False, True)
                                G_S, M_S = MPNN(S_B, S_F, D, num_iter, 'MPNN',
                                                False, True)

                                GNN_losses0.append(
                                    tf.reduce_mean(
                                        tf.reduce_sum(tf.abs(M_S - M_T), -1)))
                                GNN_losses1.append(kld_loss(G_T, G_S))

                            with tf.variable_scope('MHA'):
                                M_T = tf.reduce_mean(M_T, 2) * num_iter
                                GNN_losses2.append(
                                    Attention_knowledge(T_F,
                                                        T_B,
                                                        S_F,
                                                        S_B,
                                                        num_head=8,
                                                        extrinsic=M_T))

                    T_F, S_F = T_B, S_B

                # The three lists grow together, so checking one is enough.
                # tf.add_n rejects an empty list with a generic message; fail
                # with the same exception type but say what is wrong.
                if not GNN_losses0:
                    raise ValueError('EKI needs at least two student/teacher '
                                     'feature-map pairs to build its losses')

                tf.add_to_collection('dist', tf.add_n(GNN_losses0))
                tf.add_to_collection('dist', tf.add_n(GNN_losses1))
                tf.add_to_collection('dist', tf.add_n(GNN_losses2))