Example #1
def get_criterion(dataset, teacher, num_classes, gpu, mixup, label_smoothing):
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    criterion_kd = None
    if mixup > 0.0:
        criterion_kd = NLLMultiLabelSmooth(label_smoothing)
    elif label_smoothing > 0.0:
        criterion_kd = LabelSmoothing(label_smoothing)
    if criterion_kd:
        criterion_kd = criterion_kd.cuda()
        return criterion_kd, criterion, None

    # criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    # criterion_smooth = criterion_smooth.cuda()
    if teacher == 'none':
        return criterion, criterion, None

    if dataset == 'imagenet':
        # load model
        model_teacher = models.__dict__[teacher](pretrained=True)
        model_teacher = model_teacher.cuda()
        model_teacher = DDP(model_teacher, device_ids=[gpu])
        for p in model_teacher.parameters():
            p.requires_grad = False
        model_teacher.eval()

        criterion_kd = KD_loss.DistributionLoss()
        return criterion_kd, criterion, model_teacher
    else:
        model_teacher = get_model(teacher,
                                  num_classes=num_classes,
                                  num_channels=96)
        model_teacher = model_teacher.cuda()
        if args.distributed:  # relies on the script-level `args` namespace
            model_teacher = DDP(model_teacher, device_ids=[gpu])

        checkpoint_tar = 'teacher.pth.tar'
        print('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(
            checkpoint_tar,
            map_location=lambda storage, loc: storage.cuda(gpu))
        model_teacher.load_state_dict(checkpoint['state_dict'], strict=False)
        print("loaded checkpoint {}".format(checkpoint_tar))

        for p in model_teacher.parameters():
            p.requires_grad = False
        model_teacher.eval()

        criterion_kd = KD_loss.DistributionLoss()
        return criterion_kd, criterion, model_teacher
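A minimal sketch of how this helper might be invoked; the dataset, teacher name, number of classes, and GPU index below are illustrative placeholders, not values taken from the original script:

# Hypothetical call site for get_criterion (all argument values are placeholders).
criterion_kd, criterion_ce, model_teacher = get_criterion(
    dataset='imagenet', teacher='resnet34', num_classes=1000,
    gpu=0, mixup=0.0, label_smoothing=0.1)
# With mixup == 0 and label_smoothing > 0 the function returns the
# label-smoothing loss and no teacher; with both at 0 and a teacher name
# given, it returns the KD loss together with the frozen teacher model.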
Example #2
def main():
    if not torch.cuda.is_available():
        sys.exit(1)
    start_t = time.time()

    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    # load model
    model_teacher = models.__dict__[args.teacher](pretrained=True)
    model_teacher = nn.DataParallel(model_teacher).cuda()
    for p in model_teacher.parameters():
        p.requires_grad = False
    model_teacher.eval()

    model_student = birealnet18()
    logging.info('student:')
    logging.info(model_student)
    model_student = nn.DataParallel(model_student).cuda()

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()
    criterion_kd = KD_loss.DistributionLoss()

    all_parameters = model_student.parameters()
    weight_parameters = []
    for pname, p in model_student.named_parameters():
        if p.ndimension() == 4 or 'conv' in pname:
            weight_parameters.append(p)
    weight_parameters_id = list(map(id, weight_parameters))
    other_parameters = list(
        filter(lambda p: id(p) not in weight_parameters_id, all_parameters))

    optimizer = torch.optim.Adam(
        [{
            'params': other_parameters
        }, {
            'params': weight_parameters,
            'weight_decay': args.weight_decay
        }],
        lr=args.learning_rate,
    )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lambda step:
                                                  (1.0 - step / args.epochs),
                                                  last_epoch=-1)
    start_epoch = 0
    best_top1_acc = 0

    # checkpoint_tar = os.path.join(args.save, 'checkpoint_ba.pth.tar')
    # checkpoint = torch.load(checkpoint_tar)
    # model_student.load_state_dict(checkpoint['state_dict'], strict=False)

    checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_tar):
        logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(checkpoint_tar)
        start_epoch = checkpoint['epoch']
        best_top1_acc = checkpoint['best_top1_acc']
        model_student.load_state_dict(checkpoint['state_dict'], strict=False)
        logging.info("loaded checkpoint {} epoch = {}".format(
            checkpoint_tar, checkpoint['epoch']))

    # adjust the learning rate according to the checkpoint
    for epoch in range(start_epoch):
        scheduler.step()

    # load training data
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # data augmentation
    crop_scale = 0.08
    lighting_param = 0.1
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
        Lighting(lighting_param),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])

    train_dataset = datasets.ImageFolder(traindir, transform=train_transforms)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # load validation data
    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # validate the student model
    valid_obj, valid_top1_acc, valid_top5_acc = validate(
        start_epoch, val_loader, model_student, criterion, args)
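The parameter split above applies weight decay only to convolution weights; a standalone sketch of the same pattern on a toy model (the layer sizes and hyperparameters are illustrative, not taken from the script):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))  # toy stand-in for the student

# 4-D tensors (conv kernels) get weight decay; biases and BN parameters do not.
weight_parameters = [p for name, p in model.named_parameters()
                     if p.ndimension() == 4 or 'conv' in name]
weight_ids = set(map(id, weight_parameters))
other_parameters = [p for p in model.parameters() if id(p) not in weight_ids]

optimizer = torch.optim.Adam(
    [{'params': other_parameters},                           # no weight decay
     {'params': weight_parameters, 'weight_decay': 1e-5}],   # decay on conv weights only
    lr=1e-3)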
Example #3
def main():
    if not torch.cuda.is_available():
        sys.exit(1)
    start_t = time.time()

    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    # load model
    # model_teacher = models.__dict__[args.teacher](pretrained=True)
    model_teacher = cifar_resnet56(pretrained="cifar100")
    model_teacher = nn.DataParallel(model_teacher).cuda()
    for p in model_teacher.parameters():
        p.requires_grad = False
    model_teacher.eval()

    model_student = reactnet(100)
    logging.info('student:')
    logging.info(model_student)
    model_student = nn.DataParallel(model_student).cuda()

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()
    criterion_kd = KD_loss.DistributionLoss()

    all_parameters = model_student.parameters()
    weight_parameters = []
    for pname, p in model_student.named_parameters():
        if p.ndimension() == 4 or 'conv' in pname:
            weight_parameters.append(p)
    weight_parameters_id = list(map(id, weight_parameters))
    other_parameters = list(
        filter(lambda p: id(p) not in weight_parameters_id, all_parameters))

    optimizer = torch.optim.Adam(
        [{
            'params': other_parameters
        }, {
            'params': weight_parameters,
            'weight_decay': args.weight_decay
        }],
        lr=args.learning_rate,
    )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lambda step:
                                                  (1.0 - step / args.epochs),
                                                  last_epoch=-1)
    start_epoch = 0
    best_top1_acc = 0

    # checkpoint_tar = os.path.join(args.save, 'checkpoint_ba.pth.tar')
    # checkpoint = torch.load(checkpoint_tar)
    # model_student.load_state_dict(checkpoint['state_dict'], strict=False)

    checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_tar):
        logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(checkpoint_tar)
        start_epoch = checkpoint['epoch'] + 1
        best_top1_acc = checkpoint['best_top1_acc']
        model_student.load_state_dict(checkpoint['state_dict'], strict=False)
        logging.info("loaded checkpoint {} epoch = {}".format(
            checkpoint_tar, checkpoint['epoch']))

    # adjust the learning rate according to the checkpoint
    for epoch in range(start_epoch):
        scheduler.step()

    # load training data

    # traindir = os.path.join(args.data, 'train')
    # valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # data augmentation
    crop_scale = 0.08
    lighting_param = 0.1
    # Note: train_transforms is defined but not used below; the CIFAR-100
    # dataset constructed later applies its own inline transform.
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(crop_scale, 1.0)),
        Lighting(lighting_param),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])

    cifar100_train = datasets.CIFAR100("../data",
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.RandomCrop(size=32,
                                                                 padding=4),
                                           transforms.RandomHorizontalFlip(),
                                           transforms.ToTensor(),
                                           transforms.Normalize(
                                               mean=[0.5071, 0.4865, 0.4409],
                                               std=[0.2009, 0.1984, 0.2023],
                                           ),
                                       ]))
    cifar100_eval = datasets.CIFAR100("../data",
                                      train=False,
                                      download=False,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize(
                                              mean=[0.5071, 0.4865, 0.4409],
                                              std=[0.2009, 0.1984, 0.2023],
                                          ),
                                      ]))
    train_loader = DataLoader(
        cifar100_train,
        batch_size=64,
        shuffle=True,
    )
    val_loader = DataLoader(
        cifar100_eval,
        batch_size=64,
        shuffle=True,
    )

    # train the model
    epoch = start_epoch
    while epoch < args.epochs:
        train_obj, train_top1_acc, train_top5_acc = train(
            epoch, train_loader, model_student, model_teacher, criterion,
            optimizer, scheduler)
        valid_obj, valid_top1_acc, valid_top5_acc = validate(
            epoch, val_loader, model_student, criterion, args)

        is_best = False
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            is_best = True

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model_student.state_dict(),
                'best_top1_acc': best_top1_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)

        epoch += 1

    training_time = (time.time() - start_t) / 3600
    valid_obj, valid_top1_acc, valid_top5_acc = validate(
        epoch, val_loader, model_student, criterion, args)
    print('total training time = {} hours'.format(training_time))
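Both CIFAR examples resume the decay schedule by replaying scheduler.step() once per already-completed epoch; a small self-contained sketch of that resume pattern (the model, epoch counts, and learning rate are placeholders):

import torch
import torch.nn as nn

total_epochs, start_epoch = 120, 40          # start_epoch would come from the checkpoint
optimizer = torch.optim.Adam(nn.Linear(4, 2).parameters(), lr=1e-3)  # placeholder model
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: 1.0 - step / total_epochs, last_epoch=-1)

# Replay the finished epochs so the learning rate matches the resumed run.
for _ in range(start_epoch):
    scheduler.step()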
Example #4
def main():
    if not torch.cuda.is_available():
        sys.exit(1)
    start_t = time.time()

    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    # load model for imagenet
    # model_teacher = models.__dict__[args.teacher](pretrained=True)
    # model_teacher = nn.DataParallel(model_teacher).cuda()

    # load teacher model for CIFAR-10
    model_teacher = resnet.resnet56()
    # print(model_teacher)
    checkpoint = torch.load("../../models/resnet56-4bfd9763.th")
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model_teacher.load_state_dict(new_state_dict)
    model_teacher = nn.DataParallel(model_teacher).cuda()


    for p in model_teacher.parameters():
        p.requires_grad = False
    model_teacher.eval()

    model_student = reactnet(num_classes=CLASSES)
    #model_student = resnet.resnet32()
    logging.info('student:')
    logging.info(model_student)
    model_student = nn.DataParallel(model_student).cuda()

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()
    criterion_kd = KD_loss.DistributionLoss()

    all_parameters = model_student.parameters()
    weight_parameters = []
    for pname, p in model_student.named_parameters():
        if p.ndimension() == 4 or 'conv' in pname:
            weight_parameters.append(p)
    weight_parameters_id = list(map(id, weight_parameters))
    other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters))

    optimizer = torch.optim.Adam(
        [{'params': other_parameters},
         {'params': weight_parameters, 'weight_decay': args.weight_decay}],
        lr=args.learning_rate,
    )
    #optimizer = torch.optim.SGD(model_student.parameters(), lr=args.learning_rate,
    #                            momentum=args.momentum,
    #                            weight_decay=args.weight_decay)
    #scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/args.epochs), last_epoch=-1)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 45], gamma=0.1)
    start_epoch = 0
    best_top1_acc = 0

    checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_tar):
        logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(checkpoint_tar)
        start_epoch = checkpoint['epoch'] + 1
        best_top1_acc = checkpoint['best_top1_acc']
        model_student.load_state_dict(checkpoint['state_dict'], strict=False)
        logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch']))

    # adjust the learning rate according to the checkpoint
    for epoch in range(start_epoch):
        scheduler.step()

    # load cifar-10 dataset
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True)

    val_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=True)

    class_names = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

    # train the model
    epoch = start_epoch
    while epoch < args.epochs:
        #train_obj, train_top1_acc,  train_top5_acc = train(epoch,  train_loader, model_student, model_teacher, criterion, optimizer, scheduler)
        #valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args)
        train_obj, train_top1_acc = train(epoch, train_loader, model_student,
                                          model_teacher, criterion, optimizer,
                                          scheduler)
        valid_obj, valid_top1_acc = validate(epoch, val_loader, model_student,
                                             criterion, args)

        is_best = False
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            is_best = True

        save_checkpoint({
            'epoch': epoch,
            'state_dict': model_student.state_dict(),
            'best_top1_acc': best_top1_acc,
            'optimizer': optimizer.state_dict(),
            }, is_best, args.save)

        epoch += 1

    training_time = (time.time() - start_t) / 3600
    print('total training time = {} hours'.format(training_time))
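The last example strips the 'module.' prefix that nn.DataParallel prepends to state-dict keys before loading the teacher checkpoint; a minimal sketch of that conversion on a placeholder module (the nn.Linear layer stands in for the real teacher):

from collections import OrderedDict

import torch.nn as nn

wrapped = nn.DataParallel(nn.Linear(4, 2))   # placeholder; keys become 'module.weight', ...
saved_state = wrapped.state_dict()

# Drop the 'module.' prefix so the weights load into a bare (non-DataParallel) module.
plain_state = OrderedDict((k.replace('module.', '', 1), v) for k, v in saved_state.items())
nn.Linear(4, 2).load_state_dict(plain_state)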