Example #1
def run():
    generator = Generator().to(device)

    teacher = torch.load(opt.teacher_dir + 'teacher').to(device)
    teacher.eval()
    criterion = torch.nn.CrossEntropyLoss().to(device)

    teacher = nn.DataParallel(teacher)
    generator = nn.DataParallel(generator)

    def kdloss(y, teacher_scores):
        p = F.log_softmax(y, dim=1)
        q = F.softmax(teacher_scores, dim=1)
        l_kl = F.kl_div(p, q, reduction='sum') / y.shape[0]
        return l_kl
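    # note: summing the KL divergence and dividing by the batch size is
    # equivalent to F.kl_div(p, q, reduction='batchmean') in current PyTorch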

    if opt.dataset == 'MNIST':
        # Configure data loader
        net = LeNet5Half().to(device)
        net = nn.DataParallel(net)
        data_test = MNIST(opt.data,
                          train=False,
                          transform=transforms.Compose([
                              transforms.Resize((32, 32)),
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))
                              ]))
        data_test_loader = DataLoader(data_test, batch_size=64, num_workers=1, shuffle=False)

        # Optimizers
        optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
        optimizer_S = torch.optim.Adam(net.parameters(), lr=opt.lr_S)

    if opt.dataset != 'MNIST':
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        if opt.dataset == 'cifar10':
            net = resnet.ResNet18().to(device)
            net = nn.DataParallel(net)
            data_test = CIFAR10(opt.data,
                              train=False,
                              transform=transform_test)
        if opt.dataset == 'cifar100':
            net = resnet.ResNet18(num_classes=100).to(device)
            net = nn.DataParallel(net)
            data_test = CIFAR100(opt.data,
                              train=False,
                              transform=transform_test)
        data_test_loader = DataLoader(data_test, batch_size=opt.batch_size, num_workers=0)

        # Optimizers
        optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)

        optimizer_S = torch.optim.SGD(net.parameters(), lr=opt.lr_S, momentum=0.9, weight_decay=5e-4)


    def adjust_learning_rate(optimizer, epoch, learning_rate):
        if epoch < 800:
            lr = learning_rate
        elif epoch < 1600:
            lr = 0.1 * learning_rate
        else:
            lr = 0.01 * learning_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
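    # the same step schedule could be expressed with a built-in scheduler
    # (a sketch; it would replace the manual adjust_learning_rate call):
    #   scheduler = torch.optim.lr_scheduler.MultiStepLR(
    #       optimizer_S, milestones=[800, 1600], gamma=0.1)
    #   scheduler.step()  # once per epoch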


    # ----------
    #  Training
    # ----------

    batches_done = 0
    accr_best = 0
    for epoch in range(opt.n_epochs):

        total_correct = 0
        avg_loss = 0.0
        if opt.dataset != 'MNIST':
            adjust_learning_rate(optimizer_S, epoch, opt.lr_S)

        for i in range(120):
            net.train()
            z = torch.randn(opt.batch_size, opt.latent_dim).to(device)
            optimizer_G.zero_grad()
            optimizer_S.zero_grad()
            gen_imgs = generator(z)
            outputs_T, features_T = teacher(gen_imgs, out_feature=True)
            pred = outputs_T.data.max(1)[1]
            # activation loss: reward large teacher feature activations
            loss_activation = -features_T.abs().mean()
            # one-hot loss: push teacher outputs toward confident pseudo-labels
            loss_one_hot = criterion(outputs_T, pred)
            softmax_o_T = torch.nn.functional.softmax(outputs_T, dim=1).mean(dim=0)
            # information-entropy loss: maximize the entropy of the mean
            # prediction so generated batches stay class-balanced
            loss_information_entropy = (softmax_o_T * torch.log(softmax_o_T)).sum()
            loss = loss_one_hot * opt.oh + loss_information_entropy * opt.ie + loss_activation * opt.a
            loss_kd = kdloss(net(gen_imgs.detach()), outputs_T.detach())
            loss += loss_kd
            loss.backward()
            optimizer_G.step()
            optimizer_S.step()
            if i == 1:
                print ("[Epoch %d/%d] [loss_oh: %f] [loss_ie: %f] [loss_a: %f] [loss_kd: %f]" % (epoch, opt.n_epochs,loss_one_hot.item(), loss_information_entropy.item(), loss_activation.item(), loss_kd.item()))

        with torch.no_grad():
            for i, (images, labels) in enumerate(data_test_loader):
                images = images.to(device)
                labels = labels.to(device)
                net.eval()
                output = net(images)
                avg_loss += criterion(output, labels) * images.size(0)  # batch mean -> sum
                pred = output.data.max(1)[1]
                total_correct += pred.eq(labels.data.view_as(pred)).sum()

        avg_loss /= len(data_test)
        print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.item(), float(total_correct) / len(data_test)))
        accr = round(float(total_correct) / len(data_test), 4)
        if accr > accr_best:
            torch.save(net, opt.output_dir + 'student')
            torch.save(generator.state_dict(), opt.output_dir + "generator.pt")
            accr_best = accr
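
# Usage sketch (hypothetical): `opt` comes from the script's argparse setup,
# so the field list below is an assumption, not a recommended configuration.
#   opt = parser.parse_args()  # dataset, data, teacher_dir, output_dir,
#                              # n_epochs, batch_size, lr_G, lr_S, latent_dim,
#                              # oh, ie, a
#   run()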
Example #2
def kdloss(y, teacher_scores):
    p = F.log_softmax(y, dim=1)
    q = F.softmax(teacher_scores, dim=1)
    l_kl = F.kl_div(p, q, reduction='sum') / y.shape[0]
    return l_kl


def main(opt):
    """
    """
    print(f'image shape: {opt.channels} x {opt.img_size} x {opt.img_size}')

    if torch.cuda.device_count() == 0:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda')

    accr = 0
    accr_best = 0

    generator = Generator(opt).to(device)

    if opt.dataset == 'imagenet':
        assert opt.teacher_model_name != 'none', \
            'the built-in DAFL teacher does not support imagenet'
        teacher = getattr(models, opt.teacher_model_name)(pretrained=True)
        teacher = teacher.to(device)
        # teacher.eval()
        assert opt.student_model_name != 'none', \
            'the built-in DAFL student does not support imagenet'
        net = getattr(models, opt.student_model_name)(pretrained=False)
        net = net.to(device)

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # for optimizing the teacher model
        if opt.train_teacher:
            data_train = torchvision.datasets.ImageNet(
                opt.data_dir, split='train', transform=transform_train)
            data_train_loader = DataLoader(data_train,
                                           batch_size=opt.batch_size,
                                           shuffle=True,
                                           num_workers=4,
                                           pin_memory=True)
            optimizer = torch.optim.Adam(teacher.parameters(), lr=0.001)

        # for optimizing the student model
        data_test = torchvision.datasets.ImageNet(opt.data_dir,
                                                  split='val',
                                                  transform=transform_test)
        data_test_loader = DataLoader(data_test,
                                      batch_size=opt.batch_size,
                                      num_workers=4,
                                      shuffle=False)
        optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
        optimizer_S = torch.optim.SGD(net.parameters(),
                                      lr=opt.lr_S,
                                      momentum=0.9,
                                      weight_decay=5e-4)

    else:
        if opt.dataset == 'MNIST':
            # use the original DAFL network
            if opt.teacher_model_name == 'none':
                teacher = LeNet5().to(device)
            # use torchvision models
            else:
                teacher = getattr(models,
                                  opt.teacher_model_name)(pretrained=False)
                # nn.Conv2d expects a bool for bias, not the old Parameter
                teacher.conv1 = nn.Conv2d(
                    1, teacher.conv1.out_channels, teacher.conv1.kernel_size,
                    teacher.conv1.stride, teacher.conv1.padding,
                    teacher.conv1.dilation, teacher.conv1.groups,
                    teacher.conv1.bias is not None, teacher.conv1.padding_mode)
                teacher.fc = nn.Linear(teacher.fc.in_features, 10)
                teacher = teacher.to(device)

            # use the original DAFL network
            if opt.student_model_name == 'none':
                net = LeNet5Half().to(device)
            # use torchvision models
            else:
                net = getattr(models, opt.student_model_name)()
                # nn.Conv2d expects a bool for bias, not the old Parameter
                net.conv1 = nn.Conv2d(1, net.conv1.out_channels,
                                      net.conv1.kernel_size, net.conv1.stride,
                                      net.conv1.padding, net.conv1.dilation,
                                      net.conv1.groups,
                                      net.conv1.bias is not None,
                                      net.conv1.padding_mode)
                net.fc = nn.Linear(net.fc.in_features, 10)
                net = net.to(device)

            # for optimizing the teacher model
            if opt.train_teacher:
                data_train = MNIST(opt.data_dir,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307, ),
                                                            (0.3081, ))
                                   ]))
                data_train_loader = DataLoader(data_train,
                                               batch_size=256,
                                               shuffle=True,
                                               num_workers=4)
                optimizer = torch.optim.Adam(teacher.parameters(), lr=0.001)

            # for optimizing the student model
            data_test = MNIST(opt.data_dir,
                              download=True,
                              train=False,
                              transform=transforms.Compose([
                                  transforms.Resize((32, 32)),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.1307, ), (0.3081, ))
                              ]))
            data_test_loader = DataLoader(data_test,
                                          batch_size=64,
                                          num_workers=4,
                                          shuffle=False)
            optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
            optimizer_S = torch.optim.Adam(net.parameters(), lr=opt.lr_S)

        elif opt.dataset == 'cifar10':
            # use the original DAFL network
            if opt.teacher_model_name == 'none':
                teacher = resnet.ResNet34().to(device)
            # use torchvision models
            else:
                teacher = getattr(models,
                                  opt.teacher_model_name)(pretrained=False)
                teacher.fc = nn.Linear(teacher.fc.in_features, 10)
                teacher = teacher.to(device)

            # use the original DAFL network
            if opt.student_model_name == 'none':
                net = resnet.ResNet18().to(device)

            # use torchvision models
            else:
                net = getattr(models, opt.student_model_name)()
                net.fc = nn.Linear(net.fc.in_features, 10)
                net = net.to(device)

            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])

            # for optimizing the teacher model
            if opt.train_teacher:
                data_train = CIFAR10(opt.data_dir,
                                     download=True,
                                     transform=transform_train)
                data_train_loader = DataLoader(data_train,
                                               batch_size=128,
                                               shuffle=True,
                                               num_workers=4)
                optimizer = torch.optim.SGD(teacher.parameters(),
                                            lr=0.1,
                                            momentum=0.9,
                                            weight_decay=5e-4)

            # for optimizing the student model
            data_test = CIFAR10(opt.data_dir,
                                download=True,
                                train=False,
                                transform=transform_test)
            data_test_loader = DataLoader(data_test,
                                          batch_size=100,
                                          num_workers=4)
            optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
            optimizer_S = torch.optim.SGD(net.parameters(),
                                          lr=opt.lr_S,
                                          momentum=0.9,
                                          weight_decay=5e-4)

        elif opt.dataset == 'cifar100':
            # use the original DAFL network
            if opt.teacher_model_name == 'none':
                teacher = resnet.ResNet34(num_classes=100).to(device)
            # use torchvision models
            else:
                teacher = getattr(models,
                                  opt.teacher_model_name)(pretrained=False)
                teacher.fc = nn.Linear(teacher.fc.in_features, 100)
                teacher = teacher.to(device)

            # use the original DAFL network
            if opt.student_model_name == 'none':
                net = resnet.ResNet18(num_classes=100).to(device)
            # use torchvision models
            else:
                net = getattr(models, opt.student_model_name)()
                net.fc = nn.Linear(net.fc.in_features, 100)
                net = net.to(device)

            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])

            # for optimizing the teacher model
            if opt.train_teacher:
                data_train = CIFAR100(opt.data_dir,
                                      download=True,
                                      transform=transform_train)
                data_train_loader = DataLoader(data_train,
                                               batch_size=128,
                                               shuffle=True,
                                               num_workers=4)
                optimizer = torch.optim.SGD(teacher.parameters(),
                                            lr=0.1,
                                            momentum=0.9,
                                            weight_decay=5e-4)

            # for optimizing the student model
            data_test = CIFAR100(opt.data_dir,
                                 download=True,
                                 train=False,
                                 transform=transform_test)
            data_test_loader = DataLoader(data_test,
                                          batch_size=100,
                                          num_workers=4)
            optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
            optimizer_S = torch.optim.SGD(net.parameters(),
                                          lr=opt.lr_S,
                                          momentum=0.9,
                                          weight_decay=5e-4)

    # train the teacher model on the specified dataset
    if opt.train_teacher:
        train_teacher(teacher, data_train_loader, data_test_loader, optimizer,
                      opt.n_epochs_teacher)

    if torch.cuda.device_count() > 1:
        teacher = nn.DataParallel(teacher)
        generator = nn.DataParallel(generator)
        net = nn.DataParallel(net)

    criterion = torch.nn.CrossEntropyLoss().to(device)
    if opt.pretest:
        test(teacher, data_test_loader)

    # ----------
    #  Training
    # ----------
    batches_done = 0
    for epoch in range(opt.n_epochs):
        total_correct = 0
        avg_loss = 0.0
        if opt.dataset != 'MNIST':
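            # adjust_learning_rate is not shown in this excerpt; Example #1
            # defines an equivalent 800/1600-epoch step schedule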
            adjust_learning_rate(optimizer_S, epoch, opt.lr_S)

        for i in range(120):
            net.train()
            z = torch.randn(opt.batch_size, opt.latent_dim).to(device)
            optimizer_G.zero_grad()
            optimizer_S.zero_grad()
            gen_imgs = generator(z)
            # the teacher's weights are never updated, but gradients still
            # flow through it so that the generator can be trained
            if opt.dataset != 'imagenet' and opt.teacher_model_name == 'none':
                outputs_T, features_T = teacher(gen_imgs, out_feature=True)
            else:
                features = [torch.Tensor().to(device)]

                def hook_features(model, input, output):
                    features[0] = torch.cat((features[0], output.to(device)),
                                            0)

                if torch.cuda.device_count() > 1:
                    handle = teacher.module.avgpool.register_forward_hook(
                        hook_features)
                else:
                    handle = teacher.avgpool.register_forward_hook(
                        hook_features)
                outputs_T = teacher(gen_imgs)
                handle.remove()  # avoid stacking a new hook every iteration
                features_T = features[0]

            pred = outputs_T.data.max(1)[1]
            loss_activation = -features_T.abs().mean()
            loss_one_hot = criterion(outputs_T, pred)
            softmax_o_T = torch.nn.functional.softmax(outputs_T,
                                                      dim=1).mean(dim=0)
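            # note: Example #1 uses torch.log here; torch.log10 only rescales
            # the entropy term by a constant factor of 1 / ln(10)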
            loss_information_entropy = (softmax_o_T *
                                        torch.log10(softmax_o_T)).sum()
            loss = (loss_one_hot * opt.oh + loss_information_entropy * opt.ie +
                    loss_activation * opt.a)

            loss_kd = kdloss(net(gen_imgs.detach()), outputs_T.detach())

            loss += loss_kd

            loss.backward()
            optimizer_G.step()
            optimizer_S.step()
            if i == 1:
                print(f'[Epoch {epoch}/{opt.n_epochs}] '
                      f'[loss_oh: {loss_one_hot.item()}] '
                      f'[loss_ie: {loss_information_entropy.item()}] '
                      f'[loss_a: {loss_activation.item()}] '
                      f'[loss_kd: {loss_kd.item()}]')

        test(net, data_test_loader)
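
# test() is not included in this excerpt. A minimal sketch of what it
# presumably does, mirroring the evaluation loop in Example #1 (the real
# helper's signature and reporting may differ):
def test(model, loader, device=torch.device('cuda')):
    model.eval()
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    total_loss, total_correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            total_loss += criterion(output, labels).item()
            total_correct += output.argmax(1).eq(labels).sum().item()
            total += labels.size(0)
    print('Test Avg. Loss: %f, Accuracy: %f' %
          (total_loss / total, total_correct / total))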