Example #1
0
 def arcface_layer(self, inputs, labels, output_num, weights):
   ''' Apply ArcFace using the hyper-parameters stored in the net config.

   Reads `scale`, `margin` and `limit_to_pi` from
   `self.netconf['arcface_params']` and forwards them to `arcface_loss`
   as its `s`, `m` and `limit_to_pi` keyword arguments, together with the
   given tensors.

   Args:
     inputs: passed through to `arcface_loss` (presumably the embeddings).
     labels: passed through to `arcface_loss` (presumably class ids).
     output_num: passed through to `arcface_loss` (presumably the number
       of output classes).
     weights: passed through to `arcface_loss` (presumably the class
       weight matrix).

   Returns:
     Whatever `arcface_loss` returns — presumably the scaled ArcFace
     logits; confirm against `arcface_loss`.

   Raises:
     KeyError: if `arcface_params` or one of its three keys is missing
       (assuming `netconf` is a plain dict — confirm).
   '''
   arc_cfg = self.netconf['arcface_params']
   scale = arc_cfg['scale']
   margin = arc_cfg['margin']
   clip_to_pi = arc_cfg['limit_to_pi']
   return arcface_loss(inputs, labels, output_num, weights,
                       s=scale, m=margin, limit_to_pi=clip_to_pi)
Example #2
0
  def test_arcface_loss(self):
    ''' Test arcface_loss on a trivial orthonormal setup.

    Embeddings and class weights are both the identity matrix and row i is
    labelled as class i. Each embedding therefore has cosine 1 with its own
    class weight and cosine 0 with every other one. The expected output is
    `scale * cos(margin)` on the diagonal (target classes, angle 0 plus the
    additive margin) and 0 everywhere else.
    '''

    def gen_fake_data(batch_size, embedding_size, num_spks):
      ''' Generate identity embeddings with one distinct label per row. '''
      # The identity layout below only makes sense when all sizes agree.
      assert batch_size == embedding_size
      assert num_spks == embedding_size
      embeddings = np.eye(batch_size, dtype='float32')
      # Row i belongs to speaker i.
      labels = np.arange(batch_size, dtype='int32')
      return embeddings, labels

    with self.cached_session():
      batch_size = 4
      embedding_size = 4
      num_spks = 4
      scale = 64.0
      margin = 0.5
      embeddings, labels = gen_fake_data(batch_size, embedding_size, num_spks)
      # Reuse the identity matrix as class weights so every cosine is 0 or 1.
      weights = embeddings

      weights_tensor = tf.constant(weights)
      embeddings_tensor = tf.constant(embeddings)
      labels_tensor = tf.constant(labels)

      # Target logits: scale * cos(0 + margin) (== 56.165283 for 64.0 / 0.5);
      # non-target logits have cosine 0 and stay 0.
      output_true = (scale * np.cos(margin) *
                     np.eye(batch_size, num_spks)).astype('float32')
      output = loss_utils.arcface_loss(
          embeddings_tensor,
          labels_tensor,
          num_spks,
          weights_tensor,
          s=scale,
          m=margin,
          limit_to_pi=True)
      self.assertAllClose(self.evaluate(output), output_true)