gpu_id = int(config['gpu']['ind'])

    checkpoint_callback = ModelCheckpoint(filepath=os.path.join(config['exp']['path'], '\\'),
                                          monitor='avg_val_loss',
                                          verbose=False,
                                          save_last=True,
                                          save_top_k=1,
                                          mode='min',
                                          save_weights_only=False,
                                          period=1)

    logger = TensorBoardLogger(save_dir=os.path.dirname(config['exp']['path']),
                               name=os.path.basename(config['exp']['path']),
                               version='log')

    net = ResNet.resnet18(pretrained=True, progress=True, num_classes=100)
    model = ClassificationPL(net=net, configs=config)

    trainer = pl.Trainer(fast_dev_run=False,
                         max_epochs=2,
                         precision=32,
                         check_val_every_n_epoch=1,
                         distributed_backend=None,
                         benchmark=True,
                         gpus=gpu_id,
                         limit_test_batches=1.0,
                         limit_val_batches=1.0,
                         log_save_interval=1,
                         row_log_interval=1,
                         logger=logger,
                         checkpoint_callback=checkpoint_callback
예제 #2
0
def main():
    """Distil a pretrained teacher network into a student network on ImageNet.

    Steps, in order:
      1. optionally seed the RNGs (``config.train_params.use_seed``);
      2. build the train/val loaders from ``config.data``;
      3. pick the teacher/student pair from ``config.net_type`` and wrap them
         in a ``WSLDistiller``;
      4. optionally resume from ``config.optim.resume_path``;
      5. validate the teacher once, then alternate one epoch of distillation
         training with one validation pass of the student, saving a
         checkpoint after every epoch.

    The module-level ``best_err1`` / ``best_err5`` are updated in place.

    Raises:
        RuntimeError: if ``config.net_type`` is neither ``'mobilenet'`` nor
            ``'resnet'``.
    """
    global best_err1, best_err5

    if config.train_params.use_seed:
        utils.set_seed(config.train_params.seed)

    imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(config.data.data_path, 'train'),
                                        testFolder=os.path.join(config.data.data_path, 'val'),
                                        num_workers=config.data.num_workers,
                                        type_of_data_augmentation=config.data.type_of_data_aug,
                                        data_config=config.data)

    train_loader, val_loader = imagenet.getTrainTestLoader(config.data.batch_size)

    # Teacher is always loaded with pretrained weights; the student starts
    # from scratch (ResNet-18) or from MobileNet's own default initialisation.
    if config.net_type == 'mobilenet':
        t_net = ResNet.resnet50(pretrained=True)
        s_net = Mov.MobileNet()
    elif config.net_type == 'resnet':
        t_net = ResNet.resnet34(pretrained=True)
        s_net = ResNet.resnet18(pretrained=False)
    else:
        print('undefined network type !!!')
        raise RuntimeError('%s does not support' % config.net_type)

    import knowledge_distiller
    d_net = knowledge_distiller.WSLDistiller(t_net, s_net)

    print('Teacher Net: ')
    print(t_net)
    print('Student Net: ')
    print(s_net)
    print('the number of teacher model parameters: {}'.format(sum(p.numel() for p in t_net.parameters())))
    print('the number of student model parameters: {}'.format(sum(p.numel() for p in s_net.parameters())))

    t_net = torch.nn.DataParallel(t_net)
    s_net = torch.nn.DataParallel(s_net)
    d_net = torch.nn.DataParallel(d_net)

    # State of the optimizer stored in the resume checkpoint, if any. It can
    # only be applied once the optimizer exists, further below.
    optimizer_state = None
    if config.optim.if_resume:
        checkpoint = torch.load(config.optim.resume_path)
        d_net.module.load_state_dict(checkpoint['train_state_dict'])
        best_err1 = checkpoint['best_err1']
        best_err5 = checkpoint['best_err5']
        start_epoch = checkpoint['epoch'] + 1
        # Tolerate checkpoints that do not carry an 'optimizer' entry.
        optimizer_state = checkpoint.get('optimizer')
    else:
        start_epoch = 0

    t_net = t_net.cuda()
    s_net = s_net.cuda()
    d_net = d_net.cuda()

    ### choose optimizer parameters

    # Only the student's parameters are optimised.
    optimizer = torch.optim.SGD(list(s_net.parameters()), config.optim.init_lr,
                                momentum=config.optim.momentum, weight_decay=config.optim.weight_decay, nesterov=True)

    # The checkpoint written at the end of every epoch contains the optimizer
    # state; restore it on resume so the momentum buffers are not reset.
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    cudnn.benchmark = True
    cudnn.enabled = True

    print('Teacher network performance')
    validate(val_loader, t_net, 0)

    # NOTE(review): the upper bound is ``epochs + 1``, i.e. a fresh run
    # executes epochs + 1 iterations — confirm this matches the schedule in
    # adjust_learning_rate.
    for epoch in range(start_epoch, config.train_params.epochs + 1):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_with_distill(train_loader, d_net, optimizer, epoch)

        # evaluate on validation set
        err1, err5 = validate(val_loader, s_net, epoch)

        # remember the best top-1 error (ties count as a new best) and the
        # top-5 error measured in that same epoch, then save a checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5
        print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': s_net.module.state_dict(),
            'train_state_dict': d_net.module.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)
        gc.collect()

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
# Meaning of the positional values stored in ``_REGNET_PRESETS``.
_REGNET_FIELDS = ('depth', 'w0', 'wa', 'wm', 'group_w', 'se_on')

# RegNet hyper-parameters keyed by architecture name. Every ``regnetx-*``
# preset has ``se_on=False`` and every ``regnety-*`` preset ``se_on=True``.
_REGNET_PRESETS = {
    'regnetx-200mf': (13, 24, 36.44, 2.49, 8, False),
    'regnetx-600mf': (16, 48, 36.97, 2.24, 24, False),
    'regnetx-4.0gf': (23, 96, 38.65, 2.43, 40, False),
    'regnetx-6.4gf': (17, 184, 60.83, 2.07, 56, False),
    'regnety-200mf': (13, 24, 36.44, 2.49, 8, True),
    'regnety-600mf': (15, 48, 32.54, 2.32, 16, True),
    'regnety-4.0gf': (22, 96, 31.41, 2.24, 64, True),
    'regnety-6.4gf': (25, 112, 33.22, 2.27, 72, True),
}


def load_model(config, num_classes, dropout=None):
    """Instantiate the (non-pretrained) network selected by ``config['model']``.

    Args:
        config: mapping with ``config['model']['type']`` (model family) and
            ``config['model']['arch']`` (variant inside that family).
        num_classes: number of output classes, forwarded to every constructor.
        dropout: forwarded only to the ``resnet`` and ``arcface`` families;
            the other families ignore it.

    Returns:
        The freshly constructed network object.

    Raises:
        ValueError: if the model type is unknown, or if the architecture is
            not available for that type. The ``'assembled'`` type is
            recognised but has no architectures, so it always raises.
    """
    model_type = config['model']['type']

    # Keyword-argument sets shared by several constructors.
    plain_kwargs = {'pretrained': False, 'num_classes': num_classes}
    progress_kwargs = dict(plain_kwargs, progress=False)
    dropout_kwargs = dict(progress_kwargs, dropout=dropout)
    arcface_kwargs = dict(dropout_kwargs, scale=64.0, margin=0.25)

    def build_regnet(arch_name):
        """Expand a preset into the config dict expected by ``RegNet.RegNet``."""
        regnet_config = dict(zip(_REGNET_FIELDS, _REGNET_PRESETS[arch_name]))
        regnet_config['num_classes'] = num_classes
        return RegNet.RegNet(regnet_config)

    # type -> arch -> zero-argument builder. Builders are lambdas so a model
    # package is only touched when its architecture is actually requested.
    families = {
        'resnet': {
            'resnet18': lambda: ResNet.resnet18(**dropout_kwargs),
            'resnet50': lambda: ResNet.resnet50(**dropout_kwargs),
            'resnext50': lambda: ResNet.resnext50_32x4d(**dropout_kwargs),
            'resnet50d': lambda: ResNetD.resnet50d(**dropout_kwargs),
        },
        'tresnet': {
            'tresnetm': lambda: TResNet.TResnetM(num_classes=num_classes),
            'tresnetl': lambda: TResNet.TResnetL(num_classes=num_classes),
            'tresnetxl': lambda: TResNet.TResnetXL(num_classes=num_classes),
        },
        'regnet': {
            # ``name=name`` binds the current key; a bare closure would see
            # only the last one.
            name: (lambda name=name: build_regnet(name))
            for name in _REGNET_PRESETS
        },
        'resnest': {
            'resnest50': lambda: ResNest.resnest50(**plain_kwargs),
            'resnest101': lambda: ResNest.resnest101(**plain_kwargs),
        },
        'efficient': {
            'b0': lambda: EfficientNet.efficientnet_b0(**plain_kwargs),
            'b1': lambda: EfficientNet.efficientnet_b1(**plain_kwargs),
            'b2': lambda: EfficientNet.efficientnet_b2(**plain_kwargs),
            'b3': lambda: EfficientNet.efficientnet_b3(**plain_kwargs),
            'b4': lambda: EfficientNet.efficientnet_b4(**plain_kwargs),
            'b5': lambda: EfficientNet.efficientnet_b5(**plain_kwargs),
            'b6': lambda: EfficientNet.efficientnet_b6(**plain_kwargs),
        },
        # Recognised type with no implemented architecture: every lookup
        # falls through to the "Unsupported architecture" error below.
        'assembled': {},
        'shufflenet': {
            'v2_x0_5': lambda: ShuffleNetV2.shufflenet_v2_x0_5(**progress_kwargs),
            'v2_x1_0': lambda: ShuffleNetV2.shufflenet_v2_x1_0(**progress_kwargs),
            'v2_x1_5': lambda: ShuffleNetV2.shufflenet_v2_x1_5(**progress_kwargs),
            'v2_x2_0': lambda: ShuffleNetV2.shufflenet_v2_x2_0(**progress_kwargs),
        },
        'mobilenet': {
            'small_075': lambda: Mobilenetv3.mobilenetv3_small_075(**plain_kwargs),
            'small_100': lambda: Mobilenetv3.mobilenetv3_small_100(**plain_kwargs),
            'large_075': lambda: Mobilenetv3.mobilenetv3_large_075(**plain_kwargs),
            'large_100': lambda: Mobilenetv3.mobilenetv3_large_100(**plain_kwargs),
        },
        'rexnet': {
            'rexnet1.0x': lambda: ReXNet.rexnet(num_classes=num_classes, width_multi=1.0),
            'rexnet1.5x': lambda: ReXNet.rexnet(num_classes=num_classes, width_multi=1.5),
            'rexnet2.0x': lambda: ReXNet.rexnet(num_classes=num_classes, width_multi=2.0),
        },
        'arcface': {
            'arcface_resnet18': lambda: Arcface.arcmargin_resnet18(**arcface_kwargs),
            'arcface_resnet50': lambda: Arcface.arcmargin_resnet50(**arcface_kwargs),
            'arcface_resnext50': lambda: Arcface.arcmargin_resnext50_32x4d(**arcface_kwargs),
        },
    }

    try:
        arch_builders = families[model_type]
    except (KeyError, TypeError):  # TypeError: unhashable value in a malformed config
        raise ValueError('Unsupported model type: ' + str(model_type)) from None

    # The architecture is read only after the type is known to be valid.
    arch = config['model']['arch']
    try:
        builder = arch_builders[arch]
    except (KeyError, TypeError):
        raise ValueError('Unsupported architecture: ' + str(arch)) from None

    return builder()
예제 #4
0
        batch_size=args.batch_size, num_workers=1)
else:
    raise ValueError("Unknown dataset type")

# Reject unknown model names up front, before any model package is imported.
assert args.model in ['VGG8', 'DenseNet40', 'ResNet18'], args.model

# Build the requested network and its training loss. Each model package is
# imported lazily, only in the branch that needs it. ResNet18 trains with
# cross-entropy; the other two use the SSE criterion from wage_util.
if args.model == 'ResNet18':
    from models import ResNet
    model = ResNet.resnet18(args=args, logger=logger)
    criterion = torch.nn.CrossEntropyLoss()
elif args.model == 'DenseNet40':
    from models import DenseNet
    model = DenseNet.densenet40(args=args, logger=logger)
    criterion = wage_util.SSE()
elif args.model == 'VGG8':
    from models import VGG
    model = VGG.vgg8(args=args, logger=logger)
    criterion = wage_util.SSE()
else:
    raise ValueError("Unknown model type")

if args.cuda:
    model.cuda()

# Plain SGD with fixed momentum and weight decay; only the learning rate is
# configurable from the command line.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=0.0001,
)

# ``args.decreasing_lr`` is a comma-separated string of integers.
decreasing_lr = [int(field) for field in args.decreasing_lr.split(',')]
logger('decreasing_lr: ' + str(decreasing_lr))

# Bookkeeping for the best result so far and the file it was saved to.
best_acc = 0
old_file = None
예제 #5
0
    
# Reject unknown model names up front, before any model package is imported.
assert args.model in ['VGG8', 'DenseNet40', 'ResNet18'], args.model

# Build the requested network with pretrained weights. VGG8 and DenseNet40
# are given a local checkpoint path; ResNet18 is given ``pretrained=True``.
if args.model == 'ResNet18':
    from models import ResNet
    # FP mode pretrained model, loaded from
    # 'https://download.pytorch.org/models/resnet18-5c106cde.pth'.
    # A local checkpoint path can be passed as ``pretrained`` instead, as the
    # other branches do.
    modelCF = ResNet.resnet18(args=args, logger=logger, pretrained=True)
elif args.model == 'DenseNet40':
    from models import DenseNet
    # WAGE mode pretrained model
    model_path = './log/DenseNet40.pth'
    modelCF = DenseNet.densenet40(args=args, logger=logger, pretrained=model_path)
elif args.model == 'VGG8':
    from models import VGG
    # WAGE mode pretrained model
    model_path = './log/VGG8.pth'
    modelCF = VGG.vgg8(args=args, logger=logger, pretrained=model_path)
else:
    raise ValueError("Unknown model type")

if args.cuda:
    modelCF.cuda()

best_acc = 0
old_file = None
t_begin = time.time()

# Ready to go: switch the network to evaluation mode.
modelCF.eval()

# Accumulators for the test pass.
test_loss = 0
correct = 0
trained_with_quantization = True