def arcface_layer(self, inputs, labels, output_num, weights):
    '''
    ArcFace layer.

    Looks up the ArcFace hyper-parameters in
    ``self.netconf['arcface_params']`` (keys ``'scale'``, ``'margin'`` and
    ``'limit_to_pi'``, read in that order; a missing key raises KeyError).
    It then returns the result of ``arcface_loss`` applied to ``inputs``,
    ``labels``, ``output_num`` and ``weights``, with those values passed as
    ``s``, ``m`` and ``limit_to_pi``.
    '''
    arc_conf = self.netconf['arcface_params']
    # Generator unpacking keeps the original lookup order, so the first
    # missing key is the one reported.
    scale, margin, clip_to_pi = (
        arc_conf[key] for key in ('scale', 'margin', 'limit_to_pi'))
    return arcface_loss(inputs, labels, output_num, weights,
                        s=scale, m=margin, limit_to_pi=clip_to_pi)
def test_arcface_loss(self):
    '''
    test arcface loss

    Embeddings and class weights are both the 4x4 identity matrix, and
    sample ``i`` is labelled with class ``i``. Each embedding is therefore
    identical to its own class-weight column and orthogonal to every other
    one. The expected logits are ``s * cos(m)`` (64 * cos(0.5) ~= 56.165283)
    at the target class and 0 everywhere else.
    '''
    def gen_fake_data(batch_size, embedding_size, num_spks):
        ''' Return identity embeddings and labels 0..batch_size-1. '''
        # The identity construction only works for a square problem.
        assert batch_size == embedding_size
        assert num_spks == embedding_size
        embeddings = np.eye(batch_size, dtype='float32')
        # Sample i belongs to speaker i.
        labels = np.arange(batch_size, dtype='int32')
        return embeddings, labels

    with self.cached_session():
        batch_size = 4
        embedding_size = 4
        num_spks = 4
        # Shared by the loss call and the expected values so they cannot drift.
        scale = 64.0
        margin = 0.5
        embeddings, labels = gen_fake_data(batch_size, embedding_size, num_spks)
        # Reusing the identity as the weight matrix keeps the expected output
        # trivial to derive by hand.
        weights = embeddings
        weights_tensor = tf.constant(weights)
        embeddings_tensor = tf.constant(embeddings)
        labels_tensor = tf.constant(labels)
        # Target-class logits are s * cos(m); non-target logits are 0.
        output_true = np.eye(batch_size, num_spks, dtype='float32') * np.float32(
            scale * np.cos(margin))
        output = loss_utils.arcface_loss(
            embeddings_tensor, labels_tensor, num_spks, weights_tensor,
            s=scale, m=margin, limit_to_pi=True)
        self.assertAllClose(output.eval(), output_true)