Пример #1
0
    def test_nt_xent_loss_equals_sup_con_loss(self):
        """Check that both loss implementations agree on a normal-size batch.

        ``supervised_nt_xent_loss`` (TensorFlow) and ``SupConLoss`` (PyTorch)
        receive the same features ``self.X``, labels ``self.y`` and
        temperatures. Their outputs must match within rtol=1e-5 / atol=1e-8,
        and a NaN on either side counts as a mismatch.
        """
        l1 = supervised_nt_xent_loss(tf.constant(self.X),
                                     tf.constant(self.y),
                                     temperature=self.temperature,
                                     base_temperature=self.base_temperature)

        scl = SupConLoss(temperature=self.temperature,
                         base_temperature=self.base_temperature)
        # Reshape to (batch_size, 1, feature_dim): a singleton middle axis,
        # presumably one "view" per sample for SupConLoss -- confirm against
        # its signature. The feature dim is inferred (-1) instead of being
        # hard-coded to 128, so the test follows the fixture's feature size.
        features = torch.Tensor(self.X.reshape(self.batch_size, 1, -1))
        l2 = scl.forward(features=features, labels=torch.Tensor(self.y))
        print('\nLosses from normal batch size={}:'.format(self.batch_size))
        print('l1 = {}'.format(l1.numpy()))
        print('l2 = {}'.format(l2.numpy()))
        # assert_allclose reports both values on failure, unlike
        # assertTrue(np.allclose(...)). rtol/atol are np.allclose's defaults.
        # equal_nan=False keeps np.allclose's rule that NaN never matches,
        # since assert_allclose would otherwise accept NaN == NaN.
        np.testing.assert_allclose(l1.numpy(), l2.numpy(),
                                   rtol=1e-5, atol=1e-8, equal_nan=False)
def train_step_sup_nt_xent(x, y, temperature=0.1):
    """Run one supervised NT-Xent training step on a batch.

    The batch goes through ``encoder`` and then ``projector`` (both defined
    outside this function, called with ``training=True``), and the loss is
    ``losses.supervised_nt_xent_loss`` on the projections. The gradients
    w.r.t. both models' trainable variables are applied with ``optimizer``.
    The loss is then passed to ``train_loss`` (presumably a running metric).

    Args:
        x: data tensor, shape: (batch_size, data_dim)
        y: data labels, shape: (batch_size, )
        temperature: temperature forwarded to ``supervised_nt_xent_loss``.
            Defaults to 0.1, the value previously hard-coded here.
    """
    with tf.GradientTape() as tape:
        r = encoder(x, training=True)
        z = projector(r, training=True)
        loss = losses.supervised_nt_xent_loss(z, y, temperature=temperature)

    # Gathered after the forward pass, as before. A single list is reused
    # for the gradient computation and the update so the two always line up.
    variables = encoder.trainable_variables + projector.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    train_loss(loss)
Пример #3
0
    def test_nt_xent_loss_and_sup_con_loss_small_batch(self):
        """Contrast the two losses on a very small batch.

        Uses the small-batch fixture (``self.X_s``, ``self.y_s``).
        ``supervised_nt_xent_loss`` must return a finite value, while
        ``SupConLoss`` must return NaN.
        """
        # On a very small batch SupConLoss returns NaN, whereas
        # supervised_nt_xent_loss ignores the classes that cause it.
        l1 = supervised_nt_xent_loss(tf.constant(self.X_s),
                                     tf.constant(self.y_s),
                                     temperature=self.temperature,
                                     base_temperature=self.base_temperature)

        scl = SupConLoss(temperature=self.temperature,
                         base_temperature=self.base_temperature)
        # Reshape to (batch_size_s, 1, feature_dim). The feature dim is
        # inferred (-1) rather than hard-coded to 128.
        features = torch.Tensor(self.X_s.reshape(self.batch_size_s, 1, -1))
        l2 = scl.forward(features=features, labels=torch.Tensor(self.y_s))
        print('\nLosses from small batch size={}:'.format(self.batch_size_s))
        print('l1 = {}'.format(l1.numpy()))
        print('l2 = {}'.format(l2.numpy()))
        self.assertTrue(np.isfinite(l1.numpy()),
                        msg='supervised_nt_xent_loss should be finite')
        self.assertTrue(np.isnan(l2.numpy()),
                        msg='SupConLoss is expected to be NaN here')
Пример #4
0
 def loss_func(z, y):
     """Return the supervised NT-Xent loss of ``z`` given labels ``y``.

     Both temperatures come from the ``args`` object in the enclosing scope.
     """
     temperature = args.temperature
     base_temperature = args.base_temperature
     return losses.supervised_nt_xent_loss(
         z, y, temperature=temperature, base_temperature=base_temperature)
def test_step_sup_nt_xent(x, y, temperature=0.1):
    """Evaluate the supervised NT-Xent loss on a batch without training.

    The batch goes through ``encoder`` and ``projector`` (defined outside
    this function, called with ``training=False``). The resulting loss is
    passed to ``test_loss`` (presumably a running metric). No gradients are
    computed or applied.

    Args:
        x: data tensor, presumably shape (batch_size, data_dim).
        y: data labels, presumably shape (batch_size, ).
        temperature: temperature forwarded to ``supervised_nt_xent_loss``.
            Defaults to 0.1, the value previously hard-coded here. Keep it
            equal to the one used for training so the losses are comparable.
    """
    r = encoder(x, training=False)
    z = projector(r, training=False)
    t_loss = losses.supervised_nt_xent_loss(z, y, temperature=temperature)
    test_loss(t_loss)