def test_nbdt_gradient_soft(resnet18_cifar10, input_cifar10, label_cifar10, criterion):
    """Check that a soft tree-supervision loss can be backpropagated.

    Runs ``input_cifar10`` through ``resnet18_cifar10``, verifies the output
    is attached to the autograd graph, wraps ``criterion`` in a
    ``SoftTreeSupLoss`` built for CIFAR10 with the "induced" hierarchy, and
    calls ``backward()`` on the resulting loss. The test passes if none of
    these steps raise.

    NOTE(review): the arguments are presumably pytest fixtures defined
    elsewhere (e.g. a conftest) -- confirm.
    """
    logits = resnet18_cifar10(input_cifar10)
    assert logits.requires_grad

    # Keep the wrapped loss under its own name instead of rebinding the
    # ``criterion`` argument.
    soft_tree_loss = SoftTreeSupLoss(
        dataset="CIFAR10",
        criterion=criterion,
        hierarchy="induced",
    )
    soft_tree_loss(logits, label_cifar10).backward()
Example #2
0
def test_criterion_cifar10(criterion, label_cifar10):
    """Smoke-test ``SoftTreeSupLoss`` for CIFAR10 with the induced hierarchy.

    Builds the loss around ``criterion`` and evaluates it once on a random
    ``(1, 10)`` tensor against ``label_cifar10``. Nothing is asserted about
    the returned value; the test passes if the call does not raise.
    """
    tree_loss = SoftTreeSupLoss(
        dataset='CIFAR10', criterion=criterion, hierarchy='induced'
    )
    random_outputs = torch.randn((1, 10))
    tree_loss(random_outputs, label_cifar10)
Example #3
0
def test_criterion_tinyimagenet200(criterion):
    """Smoke-test ``SoftTreeSupLoss`` for TinyImagenet200 with the induced hierarchy.

    Builds the loss around ``criterion`` and evaluates it once on a random
    ``(1, 200)`` tensor with a single random integer target drawn from
    ``[0, 200)``. Nothing is asserted about the returned value; the test
    passes if the call does not raise.
    """
    tree_loss = SoftTreeSupLoss(
        dataset='TinyImagenet200', criterion=criterion, hierarchy='induced'
    )
    # Draw the outputs before the targets, matching the original call's
    # argument evaluation order (keeps RNG consumption identical).
    random_outputs = torch.randn((1, 200))
    random_targets = torch.randint(200, (1, ))
    tree_loss(random_outputs, random_targets)
Example #4
0
def test_criterion_cifar100(criterion):
    """Smoke-test ``SoftTreeSupLoss`` for CIFAR100 with the induced hierarchy.

    Builds the loss around ``criterion`` and evaluates it once on a random
    ``(1, 100)`` tensor with a single random integer target drawn from
    ``[0, 100)``. Nothing is asserted about the returned value; the test
    passes if the call does not raise.
    """
    tree_loss = SoftTreeSupLoss(
        dataset='CIFAR100', criterion=criterion, hierarchy='induced'
    )
    # Draw the outputs before the targets, matching the original call's
    # argument evaluation order (keeps RNG consumption identical).
    random_outputs = torch.randn((1, 100))
    random_targets = torch.randint(100, (1, ))
    tree_loss(random_outputs, random_targets)