def KD_SVD(student_feature_maps, teacher_feature_maps, K=4):
    '''
    Seung Hyun Lee, Dae Ha Kim, and Byung Cheol Song.
    Self-supervised knowledge distillation using singular value decomposition. In
    European Conference on Computer Vision, pages 339–354. Springer, 2018.

    Build the KD-SVD transfer loss.

    Each (student, teacher) feature-map pair is compressed with ``SVP.SVD`` to
    ``K`` components. The student's vectors are aligned to the teacher's with
    ``SVP.Align_rsv``, and both are scaled by the teacher's singular values.
    For every two consecutive levels, an RBF kernel between the current and
    the previous level's compressed vectors is built for the student and for
    the teacher. The loss is the summed squared difference of the two kernels,
    with the teacher kernel wrapped in ``tf.stop_gradient`` and non-finite
    entries replaced by zero.

    Args:
        student_feature_maps: list of student feature tensors.
        teacher_feature_maps: list of teacher feature tensors, paired with the
            student list by position (pairing stops at the shorter list).
        K: number of components passed to ``SVP.SVD`` / ``SVP.Align_rsv``.
            Defaults to 4, the value this function previously hard-coded.

    Returns:
        Scalar tensor: the sum of the per-level-pair RBF losses.

    Raises:
        ValueError: if fewer than two feature-map pairs are given. The loss is
            only defined between consecutive levels, so there would be nothing
            to sum.
    '''
    with tf.variable_scope('Distillation'):
        GNN_losses = []
        # Compressed vectors of the previous level (None before the first one).
        V_Tb = V_Sb = None
        for i, (sfm, tfm) in enumerate(zip(student_feature_maps, teacher_feature_maps)):
            with tf.variable_scope('Compress_feature_map%d' % i):
                Sigma_T, _, V_T = SVP.SVD(tfm, K, name='TSVD%d' % i)
                _, U_S, V_S = SVP.SVD(sfm, K, name='SSVD%d' % i)
                V_S, _, V_T = SVP.Align_rsv(V_S, V_T, U_S, Sigma_T, K)
                # Scale both student and teacher vectors by the teacher's
                # singular values.
                Sigma_T = tf.expand_dims(Sigma_T, 1)
                V_T *= Sigma_T
                V_S *= Sigma_T

            if i > 0:
                with tf.variable_scope('RBF%d' % i):
                    # Pairwise RBF kernel between the current and previous
                    # level (exp(-d^2 / 8)), for the student and the teacher.
                    S_rbf = tf.exp(-tf.square(tf.expand_dims(V_S, 2) - tf.expand_dims(V_Sb, 1)) / 8)
                    T_rbf = tf.exp(-tf.square(tf.expand_dims(V_T, 2) - tf.expand_dims(V_Tb, 1)) / 8)

                    # The teacher kernel is a fixed target: no gradient to it.
                    l2loss = (S_rbf - tf.stop_gradient(T_rbf)) ** 2
                    # Replace NaN/Inf entries with 0 so they cannot poison
                    # the summed loss.
                    l2loss = tf.where(tf.is_finite(l2loss), l2loss, tf.zeros_like(l2loss))
                    GNN_losses.append(tf.reduce_sum(l2loss))

            V_Tb, V_Sb = V_T, V_S

        if not GNN_losses:
            # tf.add_n([]) would fail with an unhelpful message; say why.
            raise ValueError(
                'KD_SVD needs at least two (student, teacher) feature-map pairs '
                'to build a loss between consecutive levels.')
        transfer_loss = tf.add_n(GNN_losses)
        return transfer_loss
def MHGD(student_feature_maps, teacher_feature_maps):
    '''
    Seunghyun Lee, Byung Cheol Song.
    Graph-based Knowledge Distillation by Multi-head Attention Network.
    British Machine Vision Conference (BMVC) 2019

    Build an MHGD transfer loss (variant with mean-centred graphs).

    NOTE(review): a second ``def MHGD`` appears later in this file. At import
    time that later definition rebinds the name, so this body is unreachable
    through ``MHGD``. Either delete this variant or give it a distinct name.

    Each feature map is compressed with ``SVP.SVD_eid`` (1 component for both
    teacher and student) and flattened to ``[-1, D]``. For every two
    consecutive levels:

    * an attention head and an estimator are run on the teacher vectors, and
      ``mean(1 - sum(V_T_ * V_T, -1))`` is added to the ``'MHA_loss'``
      collection;
    * the same attention head (``reuse=True``) is applied to teacher and
      student. Both outputs are shifted by the teacher's mean along the last
      axis, passed through ``tanh``, and compared with ``kld_loss``.

    Linear and batch-norm variables created inside are routed to the
    ``'MHA'`` collection (plus GLOBAL_VARIABLES) via ``arg_scope``.

    Args:
        student_feature_maps: list of student feature tensors.
        teacher_feature_maps: list of teacher feature tensors, paired by
            position with the student list.

    Returns:
        ``tf.add_n`` of the per-level-pair ``kld_loss`` values. ``tf.add_n``
        raises ValueError if fewer than two levels are given (empty list).
    '''
    with tf.variable_scope('MHGD'):
        with tf.contrib.framework.arg_scope([tf.contrib.layers.fully_connected], trainable = True,
                                            weights_initializer=tf.initializers.random_normal(),
                                            weights_regularizer=None,
                                            variables_collections = [tf.GraphKeys.GLOBAL_VARIABLES,'MHA']):
            with tf.contrib.framework.arg_scope([tf.contrib.layers.batch_norm], activation_fn=None, trainable = True,
                                                param_regularizers = None,
                                                variables_collections=[tf.GraphKeys.GLOBAL_VARIABLES,'MHA']):
                GNN_losses = []
                num_head = 8
                # Flattened vectors of the previous level (None before the first).
                V_Tb = V_Sb = None
                num_feat = len(student_feature_maps)
                for i, sfm, tfm in zip(range(num_feat), student_feature_maps, teacher_feature_maps):
                    with tf.variable_scope('Compress_feature_map%d'%i):
                        Sigma_T, U_T, V_T = SVP.SVD_eid(tfm, 1, name = 'TSVD%d'%i)
                        _, U_S, V_S = SVP.SVD_eid(sfm, 1, name = 'SSVD%d'%i)
                        # NOTE(review): argument order (V_T, V_S) and the 2nd
                        # return value ("mask", unused) differ from the other
                        # Align_rsv call sites in this file; confirm against
                        # the SVP module.
                        V_S, mask = SVP.Align_rsv(V_T, V_S)
                        D = V_T.get_shape().as_list()[1]
                        V_T = tf.reshape(V_T,[-1,D])
                        V_S = tf.reshape(V_S,[-1,D])
                    with tf.variable_scope('MHA%d'%i):
                        if i > 0:
                            _,D_, = V_Sb.get_shape().as_list()
                            # Hidden size halfway between the two levels' widths.
                            D2 = (D+D_)//2
                            # Auxiliary objective on the teacher only; presumably
                            # trains the attention head/estimator (semantics of
                            # Attention_head/Estimator not visible here).
                            G_T = Attention_head(V_T, V_Tb, D2, num_head, 'Attention', is_training = True)
                            V_T_ = Estimator(V_Tb, G_T, D, num_head, 'Estimator')
                            tf.add_to_collection('MHA_loss', tf.reduce_mean(1-tf.reduce_sum(V_T_*V_T, -1)) )
                            # Re-apply the same attention weights to teacher and student.
                            G_T = Attention_head(V_T, V_Tb, D2, num_head, 'Attention', reuse = True)
                            G_S = Attention_head(V_S, V_Sb, D2, num_head, 'Attention', reuse = True)
                            # NOTE(review): both graphs are centred with the
                            # *teacher's* mean; confirm this is intentional.
                            mean = tf.reduce_mean(G_T, -1, keepdims=True)
                            G_T = tf.tanh(G_T-mean)
                            G_S = tf.tanh(G_S-mean)
                            GNN_losses.append(kld_loss(G_S, G_T))
                    V_Tb = V_T
                    V_Sb = V_S
                transfer_loss = tf.add_n(GNN_losses)
                return transfer_loss
def MHGD(student_feature_maps, teacher_feature_maps):
    '''
    Seunghyun Lee, Byung Cheol Song.
    Graph-based Knowledge Distillation by Multi-head Attention Network.
    British Machine Vision Conference (BMVC) 2019

    Build the MHGD transfer loss.

    Each feature map is compressed with ``SVP.SVD_eid`` (1 component for the
    teacher, 4 for the student), aligned with ``SVP.Align_rsv``, and flattened
    to ``[-1, D]``. For every two consecutive levels:

    * an attention head and an estimator are run on the teacher vectors, and
      ``mean(1 - sum(V_T_ * V_T, -1))`` is added to the ``'MHA_loss'``
      collection;
    * the same attention head (``reuse=True``) is applied to teacher and
      student. Both outputs go through ``tanh`` and are compared with
      ``kld_loss``.

    Args:
        student_feature_maps: list of student feature tensors.
        teacher_feature_maps: list of teacher feature tensors, paired with the
            student list by position (pairing stops at the shorter list).

    Returns:
        ``tf.add_n`` of the per-level-pair ``kld_loss`` values.

    Raises:
        ValueError: if fewer than two feature-map pairs are given. The loss is
            only defined between consecutive levels, so there would be nothing
            to sum.
    '''
    with tf.variable_scope('MHGD'):
        GNN_losses = []
        num_head = 8
        # Flattened vectors of the previous level (None before the first one).
        V_Tb = V_Sb = None
        for i, (sfm, tfm) in enumerate(zip(student_feature_maps, teacher_feature_maps)):
            with tf.variable_scope('Compress_feature_map%d' % i):
                # Only the third output of SVP.SVD_eid is used below.
                _, _, V_T = SVP.SVD_eid(tfm, 1, name='TSVD%d' % i)
                _, _, V_S = SVP.SVD_eid(sfm, 4, name='SSVD%d' % i)
                V_S, V_T = SVP.Align_rsv(V_S, V_T)
                D = V_T.get_shape().as_list()[1]
                V_T = tf.reshape(V_T, [-1, D])
                V_S = tf.reshape(V_S, [-1, D])

            with tf.variable_scope('MHA%d' % i):
                if i > 0:
                    _, D_ = V_Sb.get_shape().as_list()
                    # Hidden size halfway between the two levels' widths.
                    D2 = (D + D_) // 2

                    # Auxiliary objective on the teacher only; presumably this
                    # trains the attention head and estimator (their internals
                    # are not visible here).
                    G_T = Attention_head(V_T, V_Tb, D2, num_head, 'Attention', is_training=True)
                    V_T_ = Estimator(V_Tb, G_T, D, num_head, 'Estimator', is_training=True)
                    tf.add_to_collection('MHA_loss', tf.reduce_mean(1 - tf.reduce_sum(V_T_ * V_T, -1)))

                    # Re-apply the same attention weights to teacher and student
                    # and match the resulting graphs.
                    G_T = Attention_head(V_T, V_Tb, D2, num_head, 'Attention', reuse=True)
                    G_S = Attention_head(V_S, V_Sb, D2, num_head, 'Attention', reuse=True)
                    G_T = tf.tanh(G_T)
                    G_S = tf.tanh(G_S)
                    GNN_losses.append(kld_loss(G_S, G_T))

            V_Tb, V_Sb = V_T, V_S

        if not GNN_losses:
            # tf.add_n([]) would fail with an unhelpful message; say why.
            raise ValueError(
                'MHGD needs at least two (student, teacher) feature-map pairs '
                'to build a loss between consecutive levels.')
        transfer_loss = tf.add_n(GNN_losses)
        return transfer_loss
def EKI(student_feature_maps, teacher_feature_maps):
    '''
    Build the EKI distillation losses and register them in collections.

    For each level i, the teacher feature map is compressed with
    ``SVP.SVD_eid`` (1 component), sign-normalised, passed through
    ``Plane_mapping``, and then through ``PCA_Graph``. ``PCA_Graph`` returns a
    target graph ``G_B``, the projected teacher vectors ``T_B``, a mean
    ``mean_V`` and a projection ``P_B``. The student feature map is projected
    onto the teacher's ``U_T``, l2-normalised, mapped with ``Plane_mapping``,
    centred with the teacher's ``mean_V`` and projected with ``P_B``.

    For every two consecutive levels (i > 0):

    * an MPNN between the current and previous teacher level is trained
      against ``G_B`` via ``kld_loss``; this is added to ``'MHA_loss'``;
    * MPNN is run again (positional flags ``False, True``) on teacher and
      student. The L1 distance between their ``M`` outputs, ``kld_loss`` of
      their graphs, and an ``Attention_knowledge`` term driven by the teacher
      ``M`` are appended to three separate loss lists.

    After the loop, the sum of each list is added to the ``'dist'``
    collection. No value is returned in the visible code.

    NOTE(review): with fewer than two feature-map pairs the loss lists stay
    empty and ``tf.add_n`` raises ValueError.

    Args:
        student_feature_maps: list of student feature tensors.
        teacher_feature_maps: list of teacher feature tensors, paired by
            position with the student list.
    '''
    with tf.variable_scope('EKI'):
        # Route variables of fully_connected / batch_norm layers created inside
        # to the 'MHA' collection (plus GLOBAL_VARIABLES).
        with tf.contrib.framework.arg_scope(
                [tcl.fully_connected],
                weights_regularizer=None,
                variables_collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'MHA']):
            with tf.contrib.framework.arg_scope(
                    [tcl.batch_norm],
                    activation_fn=None,
                    param_regularizers=None,
                    variables_collections=[
                        tf.GraphKeys.GLOBAL_VARIABLES, 'MHA'
                    ]):
                GNN_losses0 = []  # L1 distance between student/teacher MPNN M outputs
                GNN_losses1 = []  # kld_loss between teacher/student MPNN graphs
                GNN_losses2 = []  # Attention_knowledge terms
                # Projected vectors of the previous level (None before the first).
                # NOTE(review): D_F is assigned but never read in the visible code.
                T_F = S_F = D_F = None
                for i, (sfm, tfm) in enumerate(
                        zip(student_feature_maps, teacher_feature_maps)):
                    tz = tfm.get_shape().as_list()
                    sz = sfm.get_shape().as_list()
                    D_B = tz[-1]  # last (presumably channel) dim of the teacher map
                    with tf.variable_scope('Stacked_PCA%d' % i):
                        with tf.variable_scope('Teacher_PCA'):
                            # NOTE(review): Sigma_T is unused below.
                            Sigma_T, U_T, V_T = SVP.SVD_eid(tfm, 1, name='TSVD')
                            # Flip each vector so that max + min along axis 1
                            # is non-negative (presumably to fix the SVD sign
                            # ambiguity). NOTE(review): if max + min == 0,
                            # tf.sign returns 0 and V_T/U_T become all zeros.
                            sign = tf.sign(
                                tf.reduce_max(V_T, 1, keepdims=True) +
                                tf.reduce_min(V_T, 1, keepdims=True))
                            V_T *= sign
                            U_T *= sign
                            T_B = tf.reshape(V_T, [-1, D_B])
                            T_B = Plane_mapping(T_B)
                            G_B, T_B, mean_V, P_B = PCA_Graph(T_B)
                        with tf.variable_scope('Student_PCA'):
                            # Assumes a rank-4 map with dims 1 and 2 spatial
                            # (e.g. NHWC) -- TODO confirm.
                            sfm = tf.reshape(sfm, [-1, sz[1] * sz[2], sz[3]])
                            # Project the student map onto the teacher's U_T.
                            V_S = tf.nn.l2_normalize(
                                tf.matmul(sfm, U_T, transpose_a=True), 1)
                            # NOTE(review): reshaping to [-1, D_B] presumably
                            # requires student channels == teacher channels.
                            S_B = tf.reshape(V_S, [-1, D_B])
                            S_B = Plane_mapping(S_B)
                            # Reuse the teacher's PCA mean and projection.
                            S_B = tf.matmul(S_B - mean_V, P_B)
                    with tf.variable_scope('EKI_module%d' % i):
                        if i > 0:
                            with tf.variable_scope('MPNN'):
                                D = D_B // 2
                                num_iter = 2
                                # Auxiliary objective: teacher MPNN graph vs. PCA graph G_B.
                                G_T, _ = MPNN(T_B, T_F, D, num_iter, 'MPNN')
                                tf.add_to_collection('MHA_loss', kld_loss(G_B, G_T))
                                # Update Graph Knowledge
                                # (positional flags False, True -- meaning defined
                                # by MPNN; presumably reuse of the same weights)
                                G_T, M_T = MPNN(T_B, T_F, D, num_iter, 'MPNN', False, True)
                                G_S, M_S = MPNN(S_B, S_F, D, num_iter, 'MPNN', False, True)
                                GNN_losses0.append(
                                    tf.reduce_mean(
                                        tf.reduce_sum(tf.abs(M_S - M_T), -1)))
                                GNN_losses1.append(kld_loss(G_T, G_S))
                            with tf.variable_scope('MHA'):
                                # Average the teacher M over axis 2, scaled by num_iter.
                                M_T = tf.reduce_mean(M_T, 2) * num_iter
                                GNN_losses2.append(
                                    Attention_knowledge(T_F, T_B, S_F, S_B, num_head=8, extrinsic=M_T))
                    T_F, S_F, D_F = T_B, S_B, D_B
                tf.add_to_collection('dist', tf.add_n(GNN_losses0))
                tf.add_to_collection('dist', tf.add_n(GNN_losses1))
                tf.add_to_collection('dist', tf.add_n(GNN_losses2))