Code Example #1
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')
        
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_trm = [transforms.RandomResizedCrop(args.input_res),
                 transforms.RandomHorizontalFlip()]
    if args.autoaug:
        train_trm.append(AutoAugment())
    if args.cutout:
        train_trm.append(Cutout(length=args.input_res // 4))
    train_trm.append(transforms.ToTensor())
    train_trm.append(normalize)
    train_trm = transforms.Compose(train_trm)

    train_dataset = datasets.ImageFolder(traindir, transform=train_trm)

    val_dataset = datasets.ImageFolder(valdir,
                                       transforms.Compose([
                                           transforms.Resize(int(args.input_res / 0.875)),
                                           transforms.CenterCrop(args.input_res),
                                           transforms.ToTensor(),
                                           normalize])
                                       )

    num_classes = len(train_dataset.classes)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True, sampler=None)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.workers, pin_memory=True)
                                             
    if args.net_type == 'self_sup':
        t_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)
        s_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)
        # load weights from the pre-trained checkpoint before wrapping with DataParallel
        if args.pretrained:
            if os.path.isfile(args.pretrained):
                print("=> loading checkpoint '{}'".format(args.pretrained))
                checkpoint = torch.load(args.pretrained, map_location="cpu")

                # rename moco pre-trained keys
                state_dict = checkpoint['state_dict']
                for k in list(state_dict.keys()):
                    # retain only encoder_q up to before the embedding layer
                    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                        # remove prefix
                        state_dict[k[len("module.encoder_q."):]] = state_dict[k]
                    # delete renamed or unused k
                    del state_dict[k]

                args.start_epoch = 0
                msg = t_net.load_state_dict(state_dict, strict=False)
                assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
                msg = s_net.load_state_dict(state_dict, strict=False)
                assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

                print("=> loaded pre-trained model '{}'".format(args.pretrained))
            else:
                print("=> no checkpoint found at '{}'".format(args.pretrained))
            
    elif args.net_type == 'infomin_self':
        from collections import OrderedDict
        t_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)
        s_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)
        if args.pretrained:
            ckpt = torch.load(args.pretrained, map_location='cpu')
            state_dict = ckpt['model']
            encoder_state_dict = OrderedDict()
            for k, v in state_dict.items():
                k = k.replace('module.', '')
                if 'encoder' in k:
                    k = k.replace('encoder.', '')
                    encoder_state_dict[k] = v

            msg = t_net.load_state_dict(encoder_state_dict, strict=False)
            print(set(msg.missing_keys))
            msg = s_net.load_state_dict(encoder_state_dict, strict=False)
            print(set(msg.missing_keys))
    elif args.net_type == 'r50':
        s_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)
    else:
        raise ValueError('undefined network type: {}'.format(args.net_type))

    if not args.no_tea:
        d_net = distiller.Distiller(t_net, s_net)

    if not args.no_tea:
        print('the number of teacher model parameters: {}'.format(sum([p.data.nelement() for p in t_net.parameters()])))
    print('the number of student model parameters: {}'.format(sum([p.data.nelement() for p in s_net.parameters()])))

    if not args.no_tea:
        t_net = torch.nn.DataParallel(t_net).cuda()
        s_net = torch.nn.DataParallel(s_net).cuda()
        d_net = torch.nn.DataParallel(d_net).cuda()
    else:
        s_net = torch.nn.DataParallel(s_net).cuda()

    # define loss function (criterion) and optimizer
    if not args.label_smooth:
        criterion_CE = nn.CrossEntropyLoss().cuda()
    else:
        criterion_CE = CrossEntropyLabelSmooth(num_classes=num_classes)
    if not args.no_tea:
        optimizer = torch.optim.SGD(list(s_net.parameters()) + list(d_net.module.Connectors.parameters()), args.lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    else:
        optimizer = torch.optim.SGD(list(s_net.parameters()), args.lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    cudnn.benchmark = True

    for epoch in range(1, args.epochs+1):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        if not args.no_tea:
            if args.mixup:
                train_with_distill_mixup(train_loader, d_net, optimizer, criterion_CE, epoch, args.no_distill_epoch)
            else:
                train_with_distill(train_loader, d_net, optimizer, criterion_CE, epoch, args.no_distill_epoch)
        else:
            if args.mixup:
                train_mixup(train_loader, s_net, optimizer, criterion_CE, epoch)
            else:
                train(train_loader, s_net, optimizer, criterion_CE, epoch)
            
        # evaluate on validation set
        err1, err5 = validate(val_loader, s_net, criterion_CE, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5
        print('Current best top-1 / top-5 error:', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'arch': args.net_type,
            'state_dict': s_net.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer' : optimizer.state_dict(),
        }, is_best)
        gc.collect()

    print('Best top-1 / top-5 error:', best_err1, best_err5)
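
The function above relies on a module-level `parser` and the globals `best_err1` / `best_err5`, which are not shown. A minimal sketch of how they could be declared, with the flag names taken from the code and every default an illustrative assumption:

import argparse

# Hypothetical argument definitions covering the flags main() reads; defaults are assumptions.
parser = argparse.ArgumentParser(description='ImageNet training with optional distillation')
parser.add_argument('--data_path', type=str, default='./imagenet')
parser.add_argument('--net_type', type=str, default='r50', choices=['self_sup', 'infomin_self', 'r50'])
parser.add_argument('--pretrained', type=str, default='')
parser.add_argument('--input_res', type=int, default=224)
parser.add_argument('--autoaug', action='store_true')
parser.add_argument('--cutout', action='store_true')
parser.add_argument('--mixup', action='store_true')
parser.add_argument('--label_smooth', action='store_true')
parser.add_argument('--no_tea', action='store_true')
parser.add_argument('--no_distill_epoch', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--workers', type=int, default=8)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=1e-4)

# Track the best top-1 / top-5 errors; initialized high so the first epoch always improves them.
best_err1 = 100.0
best_err5 = 100.0
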
Code Example #2
def load_model(config, num_classes):
    if config['model']['type'] == 'resnet':
        if config['model']['arch'] == 'resnet50':
            net = ResNet.resnet50(pretrained=False,
                                  progress=False,
                                  num_classes=num_classes)
        elif config['model']['arch'] == 'resnext50':
            net = ResNet.resnext50_32x4d(pretrained=False,
                                         progress=False,
                                         num_classes=num_classes)
        elif config['model']['arch'] == 'resnet50d':
            net = ResNetD.resnet50d(pretrained=False,
                                    progress=False,
                                    num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' +
                             str(config['model']['arch']))
    elif config['model']['type'] == 'tresnet':
        if config['model']['arch'] == 'tresnetm':
            net = TResNet.TResnetM(num_classes=num_classes)
        elif config['model']['arch'] == 'tresnetl':
            net = TResNet.TResnetL(num_classes=num_classes)
        elif config['model']['arch'] == 'tresnetxl':
            net = TResNet.TResnetXL(num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' +
                             str(config['model']['arch']))
    elif config['model']['type'] == 'regnet':
        regnet_config = dict()
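        # RegNet design-space parameters: depth, initial width (w0), width slope (wa),
        # width multiplier (wm), and per-group width (group_w); se_on enables the
        # Squeeze-and-Excitation blocks used by the RegNetY variants.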
        if config['model']['arch'] == 'regnetx-200mf':
            regnet_config['depth'] = 13
            regnet_config['w0'] = 24
            regnet_config['wa'] = 36.44
            regnet_config['wm'] = 2.49
            regnet_config['group_w'] = 8
            regnet_config['se_on'] = False
            regnet_config['num_classes'] = num_classes

            net = RegNet.RegNet(regnet_config)
        elif config['model']['arch'] == 'regnetx-600mf':
            regnet_config['depth'] = 16
            regnet_config['w0'] = 48
            regnet_config['wa'] = 36.97
            regnet_config['wm'] = 2.24
            regnet_config['group_w'] = 24
            regnet_config['se_on'] = False
            regnet_config['num_classes'] = num_classes

            net = RegNet.RegNet(regnet_config)
        elif config['model']['arch'] == 'regnetx-4.0gf':
            regnet_config['depth'] = 23
            regnet_config['w0'] = 96
            regnet_config['wa'] = 38.65
            regnet_config['wm'] = 2.43
            regnet_config['group_w'] = 40
            regnet_config['se_on'] = False
            regnet_config['num_classes'] = num_classes

            net = RegNet.RegNet(regnet_config)
        elif config['model']['arch'] == 'regnetx-6.4gf':
            regnet_config['depth'] = 17
            regnet_config['w0'] = 184
            regnet_config['wa'] = 60.83
            regnet_config['wm'] = 2.07
            regnet_config['group_w'] = 56
            regnet_config['se_on'] = False
            regnet_config['num_classes'] = num_classes

            net = RegNet.RegNet(regnet_config)
        elif config['model']['arch'] == 'regnety-200mf':
            regnet_config['depth'] = 13
            regnet_config['w0'] = 24
            regnet_config['wa'] = 36.44
            regnet_config['wm'] = 2.49
            regnet_config['group_w'] = 8
            regnet_config['se_on'] = True
            regnet_config['num_classes'] = num_classes

            net = RegNet.RegNet(regnet_config)
        elif config['model']['arch'] == 'regnety-600mf':
            regnet_config['depth'] = 15
            regnet_config['w0'] = 48
            regnet_config['wa'] = 32.54
            regnet_config['wm'] = 2.32
            regnet_config['group_w'] = 16
            regnet_config['se_on'] = True
            regnet_config['num_classes'] = num_classes

            net = RegNet.RegNet(regnet_config)
        elif config['model']['arch'] == 'regnety-4.0gf':
            regnet_config['depth'] = 22
            regnet_config['w0'] = 96
            regnet_config['wa'] = 31.41
            regnet_config['wm'] = 2.24
            regnet_config['group_w'] = 64
            regnet_config['se_on'] = True
            regnet_config['num_classes'] = num_classes

            net = RegNet.RegNet(regnet_config)
        elif config['model']['arch'] == 'regnety-6.4gf':
            regnet_config['depth'] = 25
            regnet_config['w0'] = 112
            regnet_config['wa'] = 33.22
            regnet_config['wm'] = 2.27
            regnet_config['group_w'] = 72
            regnet_config['se_on'] = True
            regnet_config['num_classes'] = num_classes

            net = RegNet.RegNet(regnet_config)
        else:
            raise ValueError('Unsupported architecture: ' +
                             str(config['model']['arch']))
    elif config['model']['type'] == 'resnest':
        if config['model']['arch'] == 'resnest50':
            net = ResNest.resnest50(pretrained=False, num_classes=num_classes)
        elif config['model']['arch'] == 'resnest101':
            net = ResNest.resnest101(pretrained=False, num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' +
                             str(config['model']['arch']))
    elif config['model']['type'] == 'efficient':
        if config['model']['arch'] == 'b0':
            net = EfficientNet.efficientnet_b0(pretrained=False,
                                               num_classes=num_classes)
        elif config['model']['arch'] == 'b1':
            net = EfficientNet.efficientnet_b1(pretrained=False,
                                               num_classes=num_classes)
        elif config['model']['arch'] == 'b2':
            net = EfficientNet.efficientnet_b2(pretrained=False,
                                               num_classes=num_classes)
        elif config['model']['arch'] == 'b3':
            net = EfficientNet.efficientnet_b3(pretrained=False,
                                               num_classes=num_classes)
        elif config['model']['arch'] == 'b4':
            net = EfficientNet.efficientnet_b4(pretrained=False,
                                               num_classes=num_classes)
        elif config['model']['arch'] == 'b5':
            net = EfficientNet.efficientnet_b5(pretrained=False,
                                               num_classes=num_classes)
        elif config['model']['arch'] == 'b6':
            net = EfficientNet.efficientnet_b6(pretrained=False,
                                               num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' +
                             str(config['model']['arch']))
    elif config['model']['type'] == 'assembled':
        # not implemented here; fail explicitly instead of falling through to an undefined net
        raise NotImplementedError("'assembled' model type is not implemented")
    elif config['model']['type'] == 'shufflenet':
        if config['model']['arch'] == 'v2_x0_5':
            net = ShuffleNetV2.shufflenet_v2_x0_5(pretrained=False,
                                                  progress=False,
                                                  num_classes=num_classes)
        elif config['model']['arch'] == 'v2_x1_0':
            net = ShuffleNetV2.shufflenet_v2_x1_0(pretrained=False,
                                                  progress=False,
                                                  num_classes=num_classes)
        elif config['model']['arch'] == 'v2_x1_5':
            net = ShuffleNetV2.shufflenet_v2_x1_5(pretrained=False,
                                                  progress=False,
                                                  num_classes=num_classes)
        elif config['model']['arch'] == 'v2_x2_0':
            net = ShuffleNetV2.shufflenet_v2_x2_0(pretrained=False,
                                                  progress=False,
                                                  num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' +
                             str(config['model']['arch']))
    elif config['model']['type'] == 'mobilenet':
        if config['model']['arch'] == 'small_075':
            net = Mobilenetv3.mobilenetv3_small_075(pretrained=False,
                                                    num_classes=num_classes)
        elif config['model']['arch'] == 'small_100':
            net = Mobilenetv3.mobilenetv3_small_100(pretrained=False,
                                                    num_classes=num_classes)
        elif config['model']['arch'] == 'large_075':
            net = Mobilenetv3.mobilenetv3_large_075(pretrained=False,
                                                    num_classes=num_classes)
        elif config['model']['arch'] == 'large_100':
            net = Mobilenetv3.mobilenetv3_large_100(pretrained=False,
                                                    num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' +
                             str(config['model']['arch']))
    elif config['model']['type'] == 'rexnet':
        if config['model']['arch'] == 'rexnet1.0x':
            net = ReXNet.rexnet(num_classes=num_classes, width_multi=1.0)
        elif config['model']['arch'] == 'rexnet1.5x':
            net = ReXNet.rexnet(num_classes=num_classes, width_multi=1.5)
        elif config['model']['arch'] == 'rexnet2.0x':
            net = ReXNet.rexnet(num_classes=num_classes, width_multi=2.0)
        else:
            raise ValueError('Unsupported architecture: ' +
                             str(config['model']['arch']))
    else:
        raise ValueError('Unsupported architecture: ' +
                         str(config['model']['type']))

    return net
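
A minimal usage sketch, assuming a nested config dict with the keys the function reads above (the specific type/arch values are illustrative only):

config = {'model': {'type': 'regnet', 'arch': 'regnety-600mf'}}
net = load_model(config, num_classes=1000)
print(sum(p.numel() for p in net.parameters()), 'parameters')
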
Code Example #3
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(traindir,
                                         transforms.Compose([
                                             transforms.RandomResizedCrop(224),
                                             transforms.RandomHorizontalFlip(),
                                             transforms.ToTensor(),
                                             normalize])
                                         )
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True, sampler=None)
    val_dataset = datasets.ImageFolder(valdir,
                                       transforms.Compose([
                                           transforms.Resize(256),
                                           transforms.CenterCrop(224),
                                           transforms.ToTensor(),
                                           normalize])
                                       )
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.workers, pin_memory=True)

    if args.net_type == 'mobilenet':
        t_net = ResNet.resnet50(pretrained=True)
        s_net = Mov.MobileNet()
    elif args.net_type == 'resnet':
        t_net = ResNet.resnet152(pretrained=True)
        s_net = ResNet.resnet50(pretrained=False)
    else:
        raise ValueError('undefined network type: {}'.format(args.net_type))

    d_net = distiller.Distiller(t_net, s_net)

    print('Teacher Net:')
    print(t_net)
    print('Student Net:')
    print(s_net)
    print('the number of teacher model parameters: {}'.format(sum([p.data.nelement() for p in t_net.parameters()])))
    print('the number of student model parameters: {}'.format(sum([p.data.nelement() for p in s_net.parameters()])))

    t_net = torch.nn.DataParallel(t_net).cuda()
    s_net = torch.nn.DataParallel(s_net).cuda()
    d_net = torch.nn.DataParallel(d_net).cuda()

    # define loss function (criterion) and optimizer
    criterion_CE = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(list(s_net.parameters()) + list(d_net.module.Connectors.parameters()), args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    cudnn.benchmark = True

    print('Teacher network performance')
    validate(val_loader, t_net, criterion_CE, 0)

    for epoch in range(1, args.epochs+1):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_with_distill(train_loader, d_net, optimizer, criterion_CE, epoch)
        # evaluate on validation set
        err1, err5 = validate(val_loader, s_net, criterion_CE, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5
        print('Current best top-1 / top-5 error:', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'arch': args.net_type,
            'state_dict': s_net.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer' : optimizer.state_dict(),
        }, is_best)
        gc.collect()

    print('Best top-1 / top-5 error:', best_err1, best_err5)
Code Example #4
def main():
    global best_err1, best_err5

    if config.train_params.use_seed:
        utils.set_seed(config.train_params.seed)

    imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(config.data.data_path, 'train'),
                                        testFolder=os.path.join(config.data.data_path, 'val'),
                                        num_workers=config.data.num_workers,
                                        type_of_data_augmentation=config.data.type_of_data_aug,
                                        data_config=config.data)

    train_loader, val_loader = imagenet.getTrainTestLoader(config.data.batch_size)

    if config.net_type == 'mobilenet':
        t_net = ResNet.resnet50(pretrained=True)
        s_net = Mov.MobileNet()
    elif config.net_type == 'resnet':
        t_net = ResNet.resnet34(pretrained=True)
        s_net = ResNet.resnet18(pretrained=False)
    else:
        raise RuntimeError('unsupported network type: %s' % config.net_type)

    import knowledge_distiller
    d_net = knowledge_distiller.WSLDistiller(t_net, s_net)

    print('Teacher Net: ')
    print(t_net)
    print('Student Net: ')
    print(s_net)
    print('the number of teacher model parameters: {}'.format(sum([p.data.nelement() for p in t_net.parameters()])))
    print('the number of student model parameters: {}'.format(sum([p.data.nelement() for p in s_net.parameters()])))

    t_net = torch.nn.DataParallel(t_net)
    s_net = torch.nn.DataParallel(s_net)
    d_net = torch.nn.DataParallel(d_net)

    if config.optim.if_resume:
        checkpoint = torch.load(config.optim.resume_path)
        d_net.module.load_state_dict(checkpoint['train_state_dict'])
        best_err1 = checkpoint['best_err1']
        best_err5 = checkpoint['best_err5']
        start_epoch = checkpoint['epoch'] + 1
    else:
        start_epoch = 0


    t_net = t_net.cuda()
    s_net = s_net.cuda()
    d_net = d_net.cuda()

    ### choose optimizer parameters

    optimizer = torch.optim.SGD(list(s_net.parameters()), config.optim.init_lr,
                                momentum=config.optim.momentum, weight_decay=config.optim.weight_decay, nesterov=True)

    cudnn.benchmark = True
    cudnn.enabled = True

    print('Teacher network performance')
    validate(val_loader, t_net, 0)

    for epoch in range(start_epoch, config.train_params.epochs + 1):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_with_distill(train_loader, d_net, optimizer, epoch)

        # evaluate on validation set
        err1, err5 = validate(val_loader, s_net, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5
        print('Current best top-1 / top-5 error:', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': s_net.module.state_dict(),
            'train_state_dict': d_net.module.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)
        gc.collect()

    print('Best top-1 / top-5 error:', best_err1, best_err5)
Code Example #5
import models.ResNet as resnet

print("running image analyzer")

dualnet = resnet.resnet50(pretrained=True)

print("end image analyzer")