Example #1
0
def get_model(train_model):
    """Build and return a freshly constructed network for the given name.

    :param train_model: architecture identifier, e.g. 'resnet18', 'vgg16',
        'nin' or 'googlenet'; '*_copy' names select the resnet_copy variants.
    :return: an uninitialised model instance.
    :raises ValueError: for unknown names (previously the function fell
        through and returned None, which crashed later at the call site).
    """
    if train_model == 'resnet18':
        return resnet.resnet18()
    elif train_model == 'resnet34':
        return resnet.resnet34()
    elif train_model == 'resnet50':
        return resnet.resnet50()
    elif train_model == 'resnet101':
        return resnet.resnet101()
    elif train_model == 'resnet152':
        return resnet.resnet152()
    elif train_model == 'resnet18_copy':
        return resnet_copy.resnet18()
    elif train_model == 'resnet34_copy':
        return resnet_copy.resnet34()
    elif train_model == 'resnet50_copy':
        return resnet_copy.resnet50()
    elif train_model == 'resnet101_copy':
        return resnet_copy.resnet101()
    elif train_model == 'resnet152_copy':
        # Fixed: this branch previously tested 'resnet152' again, which the
        # plain-resnet branch above already consumed, so
        # resnet_copy.resnet152() was unreachable.
        return resnet_copy.resnet152()
    elif train_model == 'vgg11':
        return vgg11()
    elif train_model == 'vgg13':
        return vgg13()
    elif train_model == 'vgg16':
        return vgg16()
    elif train_model == 'vgg19':
        return vgg19()
    elif train_model == 'nin':
        return nin()
    elif train_model == 'googlenet':
        return googlenet()
    raise ValueError('Unsupported model name: {}'.format(train_model))
Example #2
0
def main():
    """Train the selected network (teacher resnet50 or student vgg11),
    checkpointing and recording top-1 accuracy after every epoch."""
    start_epoch = 0
    epochs = 30
    res_train = []
    res_test = []

    # Use CUDA: expose only physical device 0 to this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # Data
    trainLoader, testLoader = image_load(args.train_batch, args.test_batch)

    # Load network
    if args.arch == 'teacher':
        teacher_net = resnet.resnet50(pretrained=True)
        # Fine-tuning: replace the pooling and classifier head.
        teacher_net.avgpool = nn.AvgPool2d(1, stride=1)
        teacher_net.fc = nn.Linear(2048, 1000)
        model = teacher_net
    elif args.arch == 'student':
        student_net = vgg.vgg11(pretrained=False)
        model = student_net
    else:
        # Fixed: an unknown arch previously fell through and crashed with
        # NameError at model.cuda() below.
        raise ValueError('Unknown arch: {}'.format(args.arch))

    model = model.cuda()

    # nn.CrossEntropyLoss already applies (log-)softmax internally.
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    for epoch in range(start_epoch, epochs):
        trainTop1Val = train(trainLoader, model, criterion, optimizer, epoch, print_freq=100)
        testTop1Val = test(testLoader, model, criterion, epoch, print_freq=100)
        # Fixed: the checkpoint metadata used to hardcode 'teacherNet' even
        # when training the student; record the actual architecture.
        state = {'epoch': epoch+1,
                 'arch': args.arch,
                 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict()
                 }
        # Filename kept unchanged so existing tooling that loads
        # 'teacherNet_checkpoint.pth' keeps working.
        filename =  'teacherNet_'+'checkpoint.pth'
        torch.save(state, filename)

        res_train.append(trainTop1Val)
        res_test.append(testTop1Val)

    print(res_test, res_train)
    top1AccPlot(res_train, res_test, epoch)
Example #3
0
    def set_model(self, name):
        """Instantiate the backbone selected by *name* and store it on
        ``self.model``; names other than 'VGG-11'/'RESNET-50' are ignored."""
        if name == 'VGG-11':
            backbone = vgg11(pretrained=False,
                             num_classes=self.num_classes,
                             class_size=self.class_size,
                             im_size=(1, 3, 32, 32),
                             device=self.device)
            self.model = backbone.to(self.device)

        if name == 'RESNET-50':
            backbone = resnet50(pretrained=False,
                                num_classes=self.num_classes,
                                device=self.device)
            self.model = backbone.to(self.device)
Example #4
0
def get_network(args, cfg):
    """Instantiate the CUDA network named by ``args.net``.

    Unknown names print a notice and terminate the process, exactly as the
    original if/elif chain did.
    """
    # pdb.set_trace()
    # Lazily-evaluated constructors; nothing is built until the lookup hits.
    # NOTE: resnet18/resnet50 pin the device via args.gpuid while every other
    # entry uses the default CUDA device — preserved from the original code.
    # ('inceptionv4' and 'inceptionresnetv2' were commented out upstream.)
    factories = {
        'lenet5': lambda: LeNet5().cuda(),
        'alexnet': lambda: alexnet(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'vgg16': lambda: vgg16(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'vgg13': lambda: vgg13(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'vgg11': lambda: vgg11(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'vgg19': lambda: vgg19(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'vgg16_bn': lambda: vgg16_bn(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'vgg13_bn': lambda: vgg13_bn(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'vgg11_bn': lambda: vgg11_bn(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'vgg19_bn': lambda: vgg19_bn(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'inceptionv3': lambda: inception_v3().cuda(),
        'resnet18': lambda: resnet18(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(args.gpuid),
        'resnet34': lambda: resnet34(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'resnet50': lambda: resnet50(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(args.gpuid),
        'resnet101': lambda: resnet101(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'resnet152': lambda: resnet152(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(),
        'squeezenet': lambda: squeezenet1_0().cuda(),
    }

    make = factories.get(args.net)
    if make is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    return make()
Example #5
0
def get_network(args):
    """Return the network named by ``args.net``, sized for ``args.task``.

    :param args: needs ``task`` ('cifar10'/'cifar100'), ``net`` and ``gpu``.
    :raises SystemExit: for an unsupported task or network name.
    """
    if args.task == 'cifar10':
        nclass = 10
    elif args.task == 'cifar100':
        nclass = 100
    else:
        # Fixed: an unknown task used to leave nclass unbound, crashing with
        # NameError inside the first matching network branch below.
        print('the task you have entered is not supported yet')
        sys.exit()
    #Yang added none bn vggs
    if args.net == 'vgg11':
        from models.vgg import vgg11
        net = vgg11(num_classes=nclass)
    elif args.net == 'vgg13':
        from models.vgg import vgg13
        net = vgg13(num_classes=nclass)
    elif args.net == 'vgg16':
        from models.vgg import vgg16
        net = vgg16(num_classes=nclass)
    elif args.net == 'vgg19':
        from models.vgg import vgg19
        net = vgg19(num_classes=nclass)

    elif args.net == 'resnet18':
        from models.resnet import resnet18
        net = resnet18(num_classes=nclass)
    elif args.net == 'resnet34':
        from models.resnet import resnet34
        net = resnet34(num_classes=nclass)
    elif args.net == 'resnet50':
        from models.resnet import resnet50
        net = resnet50(num_classes=nclass)
    elif args.net == 'resnet101':
        from models.resnet import resnet101
        net = resnet101(num_classes=nclass)
    elif args.net == 'resnet152':
        from models.resnet import resnet152
        net = resnet152(num_classes=nclass)

    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if args.gpu: #use_gpu
        net = net.cuda()

    return net
Example #6
0
def get_network(args):
    """Return the VGG variant named by ``args.net`` (vgg11/13/16/19).

    :raises ValueError: for any other name (previously the function fell
        through to ``return net`` and died with UnboundLocalError).
    """
    if args.net == 'vgg16':
        from models.vgg import vgg16
        net = vgg16()

    elif args.net == 'vgg11':
        from models.vgg import vgg11
        net = vgg11()

    elif args.net == 'vgg13':
        from models.vgg import vgg13
        net = vgg13()

    elif args.net == 'vgg19':
        from models.vgg import vgg19
        net = vgg19()

    else:
        raise ValueError('unsupported network: {}'.format(args.net))

    return net
Example #7
0
def get_model():
    """Construct the model selected by ``FLAGS.model`` using the input
    shape and class count reported by ``get_input_info()``."""
    model_name = str(FLAGS.model).lower()
    info = get_input_info()
    input_shape = info['input_shape']
    num_class = info['num_class']

    # Lazily-evaluated constructors keyed by the lowercased model name.
    factories = {
        'logistic': lambda: Logistic(input_shape, num_class),
        '2nn': lambda: TwoHiddenLayerFc(input_shape, num_class),
        'cnn': lambda: TwoConvOneFc(input_shape, num_class),
        'ccnn': lambda: CifarCnn(input_shape, num_class),
        'lenet': lambda: LeNet(input_shape, num_class),
        'lstm': lambda: lstm(input_shape, num_class),
        'resnet18': lambda: resnet.resnet18(pretrained=False, progress=False, device='cpu'),
        'vgg11': lambda: vgg.vgg11(pretrained=False, progress=False, device='cpu'),
    }

    if model_name not in factories:
        raise ValueError("Not support model: {}!".format(model_name))
    return factories[model_name]()
Example #8
0
def get_model(args, model_path=None):
    """
    Build the model named by ``args.model``, optionally restoring weights.

    :param args: super arguments
    :param model_path: if not None, load already trained model parameters.
    :return: model
    :raises ValueError: if args.model names an unsupported architecture.
    """
    if args.scratch:  # train model from scratch
        pretrained = False
        model_dir = None
        print("=> Loading model '{}' from scratch...".format(args.model))
    else:  # train model with pretrained model
        pretrained = True
        model_dir = os.path.join(args.root_path, args.pretrained_models_path)
        print("=> Loading pretrained model '{}'...".format(args.model))

    if args.model.startswith('resnet'):

        if args.model == 'resnet18':
            model = resnet18(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet34':
            model = resnet34(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet50':
            model = resnet50(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet101':
            model = resnet101(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet152':
            model = resnet152(pretrained=pretrained, model_dir=model_dir)
        else:
            # Fixed: e.g. 'resnet200' used to leave `model` unbound and
            # crash with NameError on the model.fc line below.
            raise ValueError("Unsupported model: {}".format(args.model))

        # Re-head the classifier for this task's class count.
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)

    elif args.model.startswith('vgg'):
        if args.model == 'vgg11':
            model = vgg11(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg11_bn':
            model = vgg11_bn(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg13':
            model = vgg13(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg13_bn':
            model = vgg13_bn(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg16':
            model = vgg16(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg16_bn':
            model = vgg16_bn(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg19':
            model = vgg19(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg19_bn':
            model = vgg19_bn(pretrained=pretrained, model_dir=model_dir)
        else:
            # Fixed: unknown vgg depth used to crash with NameError below.
            raise ValueError("Unsupported model: {}".format(args.model))

        model.classifier[6] = nn.Linear(model.classifier[6].in_features, args.num_classes)

    elif args.model == 'alexnet':
        model = alexnet(pretrained=pretrained, model_dir=model_dir)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, args.num_classes)

    else:
        # Fixed: unknown names used to fall through to `return model`
        # and crash with UnboundLocalError.
        raise ValueError("Unsupported model: {}".format(args.model))

    # Load already trained model parameters and go on training
    if model_path is not None:
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model'])

    return model
Example #9
0
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from back import Bone, utils
from models.vgg import vgg11
from datasets import cat_dog

# --- experiment configuration -------------------------------------------
data_dir = 'train'      # root folder of the training images
num_classes = 2         # binary cat-vs-dog classification
batch_size = 32
epochs_count = 20
num_workers = 8

# Build the cat/dog datasets and a batch-normalised VGG-11 classifier.
datasets = cat_dog.get_datasets(data_dir)
model = vgg11(num_classes, batch_norm=True)

# Move the model to the GPU when one is available, otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# SGD with momentum; the scheduler halves the learning rate every 10 epochs.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
criterion = nn.CrossEntropyLoss()

backbone = Bone(model,
                datasets,
                criterion,
                optimizer,
                scheduler,
                scheduler_after_ep=False,
                metric_fn=utils.accuracy_metric,
Example #10
0
def test_model(args):
    """Build the (quantization-aware) architecture named by ``args.arch``,
    optionally restore a checkpoint, and evaluate it on the test split.

    Every supported model is wrapped in ``torch.nn.DataParallel`` and moved
    to CUDA; unknown names log a message and exit the process.
    """
    # create model
    num_classes = 2
    if args.arch == 'efficientnet_b0':
        if args.pretrained:
            model = EfficientNet.from_pretrained("efficientnet-b0",
                                                 quantize=args.quantize,
                                                 num_classes=num_classes)
        else:
            model = EfficientNet.from_name(
                "efficientnet-b0",
                quantize=args.quantize,
                override_params={'num_classes': num_classes})
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'mobilenet_v1':
        model = mobilenet_v1(quantize=args.quantize, num_classes=num_classes)
        model = torch.nn.DataParallel(model).cuda()

        if args.pretrained:
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']

            # Drop the final classifier ('fc') weights when the checkpoint
            # was trained for a different number of classes.
            if num_classes != 1000:
                new_dict = {
                    k: v
                    for k, v in state_dict.items() if 'fc' not in k
                }
                state_dict = new_dict

            res = model.load_state_dict(state_dict, strict=False)

            # Only the quantization buffers and the re-headed fc layer may
            # legitimately be missing from the checkpoint.
            for missing_key in res.missing_keys:
                assert 'quantize' in missing_key or 'fc' in missing_key

    elif args.arch == 'mobilenet_v2':
        model = mobilenet_v2(pretrained=args.pretrained,
                             num_classes=num_classes,
                             quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'resnet18':
        model = resnet18(pretrained=args.pretrained,
                         num_classes=num_classes,
                         quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'resnet50':
        model = resnet50(pretrained=args.pretrained,
                         num_classes=num_classes,
                         quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'resnet152':
        model = resnet152(pretrained=args.pretrained,
                          num_classes=num_classes,
                          quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'resnet164':
        model = resnet_164(num_classes=num_classes, quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'vgg11':
        model = vgg11(pretrained=args.pretrained,
                      num_classes=num_classes,
                      quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'vgg19':
        model = vgg19(pretrained=args.pretrained,
                      num_classes=num_classes,
                      quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    else:
        logging.info('No such model.')
        sys.exit()

    # Restore a full training checkpoint (pretrained handles its own loading
    # above, so the two options are mutually exclusive here).
    if args.resume and not args.pretrained:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # NOTE(review): best_prec1 is read but never used afterwards —
            # presumably kept so a malformed checkpoint fails fast; confirm.
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
                args.resume, checkpoint['epoch']))
        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    # Deterministic kernels for evaluation; no autotuning benefit here.
    cudnn.benchmark = False
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)
    criterion = nn.CrossEntropyLoss().cuda()

    # Evaluate without building autograd graphs.
    with torch.no_grad():
        prec1 = validate(args, test_loader, model, criterion, 0)
Example #11
0
def get_network(args):
    """Return the network named by ``args.net``, sized for ``args.task``.

    :param args: needs ``task`` ('cifar10'/'cifar100'), ``net`` and ``gpu``.
    :raises SystemExit: for an unsupported task or network name.
    """
    if args.task == 'cifar10':
        nclass = 10
    elif args.task == 'cifar100':
        nclass = 100
    else:
        # Fixed: an unknown task used to leave nclass unbound, crashing with
        # NameError inside the first matching network branch below.
        print('the task you have entered is not supported yet')
        sys.exit()
    #Yang added none bn vggs
    if args.net == 'vgg16':
        from models.vgg import vgg16
        net = vgg16(num_classes=nclass)
    elif args.net == 'vgg13':
        from models.vgg import vgg13
        net = vgg13(num_classes=nclass)
    elif args.net == 'vgg11':
        from models.vgg import vgg11
        net = vgg11(num_classes=nclass)
    elif args.net == 'vgg19':
        from models.vgg import vgg19
        net = vgg19(num_classes=nclass)

    elif args.net == 'vgg16bn':
        from models.vgg import vgg16_bn
        net = vgg16_bn(num_classes=nclass)
    elif args.net == 'vgg13bn':
        from models.vgg import vgg13_bn
        net = vgg13_bn(num_classes=nclass)
    elif args.net == 'vgg11bn':
        from models.vgg import vgg11_bn
        net = vgg11_bn(num_classes=nclass)
    elif args.net == 'vgg19bn':
        from models.vgg import vgg19_bn
        net = vgg19_bn(num_classes=nclass)

    elif args.net == 'densenet121':
        from models.densenet import densenet121
        net = densenet121()
    elif args.net == 'densenet161':
        from models.densenet import densenet161
        net = densenet161()
    elif args.net == 'densenet169':
        from models.densenet import densenet169
        net = densenet169()
    elif args.net == 'densenet201':
        from models.densenet import densenet201
        net = densenet201()
    elif args.net == 'googlenet':
        from models.googlenet import googlenet
        net = googlenet(num_classes=nclass)
    elif args.net == 'inceptionv3':
        from models.inceptionv3 import inceptionv3
        net = inceptionv3()
    elif args.net == 'inceptionv4':
        from models.inceptionv4 import inceptionv4
        net = inceptionv4()
    elif args.net == 'inceptionresnetv2':
        from models.inceptionv4 import inception_resnet_v2
        net = inception_resnet_v2()
    elif args.net == 'xception':
        from models.xception import xception
        net = xception(num_classes=nclass)
    elif args.net == 'scnet':
        from models.sphereconvnet import sphereconvnet
        net = sphereconvnet(num_classes=nclass)
    elif args.net == 'sphereresnet18':
        from models.sphereconvnet import resnet18
        net = resnet18(num_classes=nclass)
    elif args.net == 'sphereresnet32':
        from models.sphereconvnet import sphereresnet32
        net = sphereresnet32(num_classes=nclass)
    elif args.net == 'plainresnet32':
        from models.sphereconvnet import plainresnet32
        net = plainresnet32(num_classes=nclass)
    elif args.net == 'ynet18':
        from models.ynet import resnet18
        net = resnet18(num_classes=nclass)
    elif args.net == 'ynet34':
        from models.ynet import resnet34
        net = resnet34(num_classes=nclass)
    elif args.net == 'ynet50':
        from models.ynet import resnet50
        net = resnet50(num_classes=nclass)
    elif args.net == 'ynet101':
        from models.ynet import resnet101
        net = resnet101(num_classes=nclass)
    elif args.net == 'ynet152':
        from models.ynet import resnet152
        net = resnet152(num_classes=nclass)

    elif args.net == 'resnet18':
        from models.resnet import resnet18
        net = resnet18(num_classes=nclass)
    elif args.net == 'resnet34':
        from models.resnet import resnet34
        net = resnet34(num_classes=nclass)
    elif args.net == 'resnet50':
        from models.resnet import resnet50
        net = resnet50(num_classes=nclass)
    elif args.net == 'resnet101':
        from models.resnet import resnet101
        net = resnet101(num_classes=nclass)
    elif args.net == 'resnet152':
        from models.resnet import resnet152
        net = resnet152(num_classes=nclass)
    elif args.net == 'preactresnet18':
        from models.preactresnet import preactresnet18
        net = preactresnet18(num_classes=nclass)
    elif args.net == 'preactresnet34':
        from models.preactresnet import preactresnet34
        net = preactresnet34(num_classes=nclass)
    elif args.net == 'preactresnet50':
        from models.preactresnet import preactresnet50
        net = preactresnet50(num_classes=nclass)
    elif args.net == 'preactresnet101':
        from models.preactresnet import preactresnet101
        net = preactresnet101(num_classes=nclass)
    elif args.net == 'preactresnet152':
        from models.preactresnet import preactresnet152
        net = preactresnet152(num_classes=nclass)
    elif args.net == 'resnext50':
        from models.resnext import resnext50
        net = resnext50(num_classes=nclass)
    elif args.net == 'resnext101':
        from models.resnext import resnext101
        net = resnext101(num_classes=nclass)
    elif args.net == 'resnext152':
        from models.resnext import resnext152
        net = resnext152(num_classes=nclass)
    elif args.net == 'shufflenet':
        from models.shufflenet import shufflenet
        net = shufflenet()
    elif args.net == 'shufflenetv2':
        from models.shufflenetv2 import shufflenetv2
        net = shufflenetv2()
    elif args.net == 'squeezenet':
        from models.squeezenet import squeezenet
        net = squeezenet()
    elif args.net == 'mobilenet':
        from models.mobilenet import mobilenet
        net = mobilenet(num_classes=nclass)
    elif args.net == 'mobilenetv2':
        from models.mobilenetv2 import mobilenetv2
        net = mobilenetv2(num_classes=nclass)
    elif args.net == 'nasnet':
        from models.nasnet import nasnet
        net = nasnet(num_classes=nclass)
    elif args.net == 'attention56':
        from models.attention import attention56
        net = attention56()
    elif args.net == 'attention92':
        from models.attention import attention92
        net = attention92()
    elif args.net == 'seresnet18':
        from models.senet import seresnet18
        net = seresnet18(num_classes=nclass)
    elif args.net == 'seresnet34':
        from models.senet import seresnet34
        net = seresnet34(num_classes=nclass)
    elif args.net == 'seresnet50':
        from models.senet import seresnet50
        net = seresnet50(num_classes=nclass)
    elif args.net == 'seresnet101':
        from models.senet import seresnet101
        net = seresnet101(num_classes=nclass)
    elif args.net == 'seresnet152':
        from models.senet import seresnet152
        net = seresnet152(num_classes=nclass)

    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if args.gpu:  #use_gpu
        net = net.cuda()

    return net
Example #12
0
    elif args.arch == 'mobilenet_v1':
        model = mobilenet_v1(num_classes=num_classes)

    elif args.arch == 'mobilenet_v2':
        model = mobilenet_v2(num_classes=num_classes)

    elif args.arch == 'resnet18':
        model = resnet18(num_classes=num_classes)

    elif args.arch == 'resnet50':
        model = resnet50(num_classes=num_classes)

    elif args.arch == 'resnet152':
        model = resnet152(num_classes=num_classes)

    elif args.arch == 'resnet164':
        model = resnet_164(num_classes=num_classes)

    elif args.arch == 'vgg11':
        model = vgg11(num_classes=num_classes)

    elif args.arch == 'vgg19':
        model = vgg19(num_classes=num_classes)

    else:
        print('No such model.')
        sys.exit()

    count_model_param_flops(model, input_res=input_res)
Example #13
0
def main():
    """Parse CLI arguments, build dataset/model per the flags, then run the
    train/validate loop, checkpointing the best top-1 error."""
    global args, best_err1
    args = parser.parse_args()

    # TensorBoard configure
    if args.tensorboard:
        configure('%s_checkpoints/%s'%(args.dataset, args.expname))

    # CUDA
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_ids)
    if torch.cuda.is_available():
        cudnn.benchmark = True  # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        kwargs = {'num_workers': 2, 'pin_memory': True}
    else:
        kwargs = {'num_workers': 2}

    # Data loading code: per-dataset channel statistics.
    if args.dataset == 'cifar10':
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
    elif args.dataset == 'cifar100':
        normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                         std=[0.2634, 0.2528, 0.2719])
    elif args.dataset == 'cub':
        normalize = transforms.Normalize(mean=[0.4862, 0.4973, 0.4293],
                                         std=[0.2230, 0.2185, 0.2472])
    elif args.dataset == 'webvision':
        normalize = transforms.Normalize(mean=[0.49274242, 0.46481857, 0.41779366],
                                         std=[0.26831809, 0.26145372, 0.27042758])
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Transforms (random flip only when augmentation is enabled).
    if args.augment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.train_image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.train_image_size),
            transforms.ToTensor(),
            normalize,
        ])
    val_transform = transforms.Compose([
        transforms.Resize(args.test_image_size),
        transforms.CenterCrop(args.test_crop_image_size),
        transforms.ToTensor(),
        normalize
    ])

    # Datasets
    num_classes = 10    # default 10 classes
    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR10('./data/', train=False, download=True, transform=val_transform)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR100('./data/', train=False, download=True, transform=val_transform)
        num_classes = 100
    elif args.dataset == 'cub':
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/train/',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/test/',
                                           transform=val_transform)
        num_classes = 200
    elif args.dataset == 'webvision':
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/train',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/val',
                                           transform=val_transform)
        num_classes = 1000
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Data Loader
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, **kwargs)

    # Create model
    if args.model == 'AlexNet':
        model = alexnet(pretrained=False, num_classes=num_classes)
    elif args.model == 'VGG':
        use_batch_normalization = True  # default use Batch Normalization
        if use_batch_normalization:
            if args.depth == 11:
                model = vgg11_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19_bn(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupported VGG depth: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
        else:
            if args.depth == 11:
                model = vgg11(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupported VGG depth: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
    elif args.model == 'Inception':
        model = inception_v3(pretrained=False, num_classes=num_classes)
    elif args.model == 'ResNet':
        if args.depth == 18:
            model = resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupported ResNet depth: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    elif args.model == 'MPN-COV-ResNet':
        if args.depth == 18:
            model = mpn_cov_resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = mpn_cov_resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = mpn_cov_resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = mpn_cov_resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = mpn_cov_resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupported MPN-COV-ResNet depth: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    else:
        # Fixed: the original message was 'Unsupport model'.format(args.model),
        # which has no placeholder, so the offending name was never shown.
        raise Exception('Unsupported model: {}'.format(args.model))

    # Get the number of model parameters
    print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    if torch.cuda.is_available():
        model = model.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_err1 = checkpoint['best_err1']
            model.load_state_dict(checkpoint['state_dict'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # Evaluate on validation set
        err1 = validate(val_loader, model, criterion, epoch)

        # Remember best err1 and save checkpoint
        is_best = (err1 <= best_err1)
        best_err1 = min(err1, best_err1)
        print("Current best accuracy (error):", best_err1)
        save_checkpoint({
            'epoch': epoch+1,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best)

    print("Best accuracy (error):", best_err1)
Example #14
0
def main_worker(gpu, args):
    """Single-GPU worker: build, load, and train or evaluate a classifier.

    The backbone is selected by ``args.arch`` and the dataset module by
    ``args.dataset``.  In ``args.evaluate`` mode the function validates once,
    writes the metric list to a text file next to the resumed checkpoint,
    and returns.  Otherwise it trains for ``args.epochs`` epochs, validating
    every 5th epoch and checkpointing the models that are best by joint
    accuracy, by AUC, and by DR accuracy.

    Args:
        gpu: CUDA device index this worker should use.
        args: parsed argparse namespace carrying every hyper-parameter
            (arch, dataset, paths, lr, multitask flags, evaluate flag, ...).
    """
    # Best-metric trackers live at module scope so they persist across epochs.
    global best_acc1
    global best_auc
    global minimum_loss
    global count
    global best_accdr
    args.gpu = gpu


    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # ---- backbone selection; imports are kept local to each branch ----
    if args.arch == "vgg11":
        from models.vgg import vgg11
        model = vgg11(num_classes=args.num_class, crossCBAM=args.crossCBAM)
    elif args.arch == "resnet50":
        from models.resnet50 import resnet50
        model = resnet50(num_classes=args.num_class, multitask=args.multitask, liu=args.liu,
                 chen=args.chen, CAN_TS=args.CAN_TS, crossCBAM=args.crossCBAM,
                         crosspatialCBAM = args.crosspatialCBAM,  choice=args.choice)
    elif args.arch == "resnet34":
        from models.resnet50 import resnet34
        model = resnet34(num_classes=args.num_class, multitask=args.multitask, liu=args.liu,
                 chen=args.chen,CAN_TS=args.CAN_TS, crossCBAM=args.crossCBAM,
                         crosspatialCBAM = args.crosspatialCBAM)
    elif args.arch == "resnet18":
        from models.resnet50 import resnet18
        model = resnet18(num_classes=args.num_class, multitask=args.multitask, liu=args.liu,
                 chen=args.chen, flagCBAM=False, crossCBAM=args.crossCBAM)
    elif args.arch == "densenet161":
        from models.densenet import densenet161
        model = densenet161(num_classes=args.num_class, multitask=args.multitask, cosface=False, liu=args.liu,
                    chen=args.chen, crossCBAM=args.crossCBAM)
    elif args.arch == "wired":
        from models.wirednetwork import CNN
        model = CNN(args, num_classes=args.num_class)
    else:
        # NOTE(review): this branch only prints; `model` stays undefined and
        # the `model.state_dict()` / `.cuda()` calls below would raise
        # NameError for an unrecognized arch.
        print ("no backbone model")

    if args.pretrained:
        print ("==> Load pretrained model")
        model_dict = model.state_dict()
        # NOTE(review): densenet121 has a weight path here but no
        # corresponding arch branch above — presumably legacy.
        pretrain_path = {"resnet50": "pretrain/resnet50-19c8e357.pth",
                         "resnet34": "pretrain/resnet34-333f7ec4.pth",
                         "resnet18": "pretrain/resnet18-5c106cde.pth",
                         "densenet161": "pretrain/densenet161-8d451a50.pth",
                         "vgg11": "pretrain/vgg11-bbd30ac9.pth",
                         "densenet121": "pretrain/densenet121-a639ec97.pth"}[args.arch]
        pretrained_dict = torch.load(pretrain_path)
        # Keep only weights whose names exist in this model, and drop the
        # ImageNet classifier head (its class count differs from ours).
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        pretrained_dict.pop('classifier.weight', None)
        pretrained_dict.pop('classifier.bias', None)
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    torch.cuda.set_device(args.gpu)
    model = model.cuda(args.gpu)



    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    if args.adam:
        optimizer = torch.optim.Adam(model.parameters(), args.base_lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location remaps tensors saved on cuda:4 onto cuda:0.
            checkpoint = torch.load(args.resume,  map_location={'cuda:4':'cuda:0'})
            # args.start_epoch = checkpoint['epoch']

            #  load partial weights when training (shape-compatible subset
            #  only); load the full state + optimizer when evaluating.
            if not args.evaluate:
                print ("load partial weights")
                model_dict = model.state_dict()
                pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict}
                model_dict.update(pretrained_dict)
                model.load_state_dict(model_dict)
            else:
                print("load whole weights")
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])

            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            exit(0)


    # ImageNet channel statistics; inputs are 224x224 crops.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    size  = 224

    # Training augmentation: resize, random crop, horizontal/vertical flips.
    tra = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomResizedCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                # transforms.RandomRotation(90),
                # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
                transforms.ToTensor(),
                normalize,
            ])
    # Deterministic eval pipeline: resize then center-crop.
    tra_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize])

    # tra = transforms.Compose([
    #     transforms.Resize(350),
    #     transforms.RandomHorizontalFlip(),
    #     transforms.RandomVerticalFlip(),
    #     # transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    #     transforms.RandomRotation([-180, 180]),
    #     transforms.RandomAffine([-180, 180], translate=[0.1, 0.1], scale=[0.7, 1.3]),
    #     transforms.RandomCrop(224),
    #     #            transforms.CenterCrop(224),
    #     transforms.ToTensor(),
    #     normalize
    # ])

    # print (args.model_dir)
    # tra = transforms.Compose([
    #     transforms.Resize(350),
    #     transforms.RandomHorizontalFlip(),
    #     transforms.RandomVerticalFlip(),
    #     # transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    #     transforms.RandomRotation([-180, 180]),
    #     transforms.RandomAffine([-180, 180], translate=[0.1, 0.1], scale=[0.7, 1.3]),
    #     transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    #     transforms.ToTensor(),
    #     normalize
    # ])
    # tra_test = transforms.Compose([
    #     transforms.Resize(350),
    #     transforms.CenterCrop(224),
    #     transforms.ToTensor(),
    #     normalize])

    # Each dataset module exposes a `traindataset` class with the same API.
    if args.dataset == 'amd':
        from datasets.amd_dataset import traindataset
    elif args.dataset == 'pm':
        from datasets.pm_dataset import traindataset
    elif args.dataset == "drdme":
        from datasets.drdme_dataset import traindataset
    elif args.dataset == "missidor":
        from datasets.missidor import traindataset
    elif args.dataset == "kaggle":
        from datasets.kaggle import traindataset
    else:
        print ("no dataset")
        exit(0)

    val_dataset = traindataset(root=args.data, mode = 'val',
                               transform=tra_test, num_class=args.num_class,
                               multitask=args.multitask, args=args)



    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)



    # One-shot evaluation: metrics are written next to the resumed checkpoint.
    if args.evaluate:
        a = time.time()
        # savedir = args.resume.replace("model_converge.pth.tar","")
        # Strip the checkpoint filename to get its containing directory.
        savedir = args.resume.replace(args.resume.split("/")[-1], "")
        # savedir = "./"
        if not args.multitask:
            acc, auc, precision_dr, recall_dr, f1score_dr  = validate(val_loader, model, args)
            result_list = [acc, auc, precision_dr, recall_dr, f1score_dr]
            print ("acc, auc, precision, recall, f1", acc, auc, precision_dr, recall_dr, f1score_dr)

            save_result_txt(savedir, result_list)
            print("time", time.time() - a)
            return
        else:
            acc_dr, acc_dme, acc_joint, other_results, se, sp = validate(val_loader, model, args)
            print ("acc_dr, acc_dme, acc_joint", acc_dr, acc_dme, acc_joint)
            exit(0)
            # NOTE(review): everything below in this branch is unreachable
            # because of the exit(0) above — likely a debugging leftover.
            print ("auc_dr, auc_dme, precision_dr, precision_dme, recall_dr, recall_dme, f1score_dr, f1score_dme",
                   other_results)
            print ("se, sp", se, sp)
            result_list = [acc_dr, acc_dme, acc_joint]
            result_list += other_results
            result_list += [se, sp]
            save_result_txt(savedir, result_list)

            print ("time", time.time()-a)
            return

    train_dataset = traindataset(root=args.data, mode='train', transform=tra, num_class=args.num_class,
                                 multitask=args.multitask, args=args)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True,worker_init_fn=worker_init_fn)


    # TensorBoard logging; the full arg namespace is stored for provenance.
    writer = SummaryWriter()
    writer.add_text('Text', str(args))
    #
    from lr_scheduler import LRScheduler
    lr_scheduler = LRScheduler(optimizer, len(train_loader), args)

    for epoch in range(args.start_epoch, args.epochs):
        is_best = False
        is_best_auc = False
        is_best_acc = False
        # lr = adjust_learning_rate(optimizer, epoch, args)
        # writer.add_scalar("lr", lr, epoch)
        # train for one epoch
        loss_train = train(train_loader, model, criterion, lr_scheduler, writer, epoch, optimizer, args)
        writer.add_scalar('Train loss', loss_train, epoch)

        # evaluate on validation set every 5 epochs; the tracked "best"
        # metric depends on the dataset / multitask configuration.
        if epoch % 5 == 0:
            if args.dataset == "kaggle":
                acc_dr, auc_dr = validate(val_loader, model, args)
                writer.add_scalar("Val acc_dr", acc_dr, epoch)
                writer.add_scalar("Val auc_dr", auc_dr, epoch)
                is_best = acc_dr >= best_acc1
                best_acc1 = max(acc_dr, best_acc1)
            elif not args.multitask:
                acc, auc, precision, recall, f1 = validate(val_loader, model, args)
                writer.add_scalar("Val acc_dr", acc, epoch)
                writer.add_scalar("Val auc_dr", auc, epoch)
                # NOTE(review): here best_acc1 tracks AUC, not accuracy —
                # the variable name is misleading but consistent in-file.
                is_best = auc >= best_acc1
                best_acc1 = max(auc, best_acc1)
            else:
                acc_dr, acc_dme, joint_acc, other_results, se, sp , losses = validate(val_loader, model, args,criterion)
                writer.add_scalar("Val acc_dr", acc_dr, epoch)
                writer.add_scalar("Val acc_dme", acc_dme, epoch)
                writer.add_scalar("Val acc_joint", joint_acc, epoch)
                writer.add_scalar("Val auc_dr", other_results[0], epoch)
                writer.add_scalar("Val auc_dme", other_results[1], epoch)
                writer.add_scalar("val loss", losses, epoch)
                is_best = joint_acc >= best_acc1
                best_acc1 = max(joint_acc, best_acc1)

                is_best_auc = other_results[0] >= best_auc
                best_auc = max(other_results[0], best_auc)

                is_best_acc = acc_dr >= best_accdr
                best_accdr = max(acc_dr, best_accdr)

        # Persist a checkpoint per tracked metric unless --invalid is set.
        if not args.invalid:
            if is_best:
                save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                }, is_best, filename = "model_converge.pth.tar", save_dir=args.model_dir)

            if is_best_auc:
                save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_auc,
                'optimizer' : optimizer.state_dict(),
                }, False, filename = "converge_auc.pth.tar", save_dir=args.model_dir)

            if is_best_acc:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_accdr,
                    'optimizer': optimizer.state_dict(),
                }, False, filename="converge_acc.pth.tar", save_dir=args.model_dir)
Example #15
0
def main_worker(gpu, args):
    """Single-GPU worker (variant): build, load, and train or evaluate.

    Differences from the sibling worker above: Adam is the fixed optimizer,
    resuming restores ``start_epoch`` and the full model/optimizer state,
    augmentation uses heavy rotation/affine on 350-px resizes, evaluation
    goes through ``multi_validate`` (test-time augmentation), validation
    runs every 20th epoch, and only the joint-accuracy best is saved.

    Args:
        gpu: CUDA device index this worker should use.
        args: parsed argparse namespace with all hyper-parameters.
    """
    # Best-metric trackers live at module scope so they persist across epochs.
    global best_acc1
    global minimum_loss
    global count
    args.gpu = gpu

    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # ---- backbone selection; imports are kept local to each branch ----
    if args.arch == "vgg11":
        from models.vgg import vgg11
        model = vgg11(num_classes=args.num_class, crossCBAM=args.crossCBAM)
    elif args.arch == "resnet50":
        from models.resnet50 import resnet50
        model = resnet50(num_classes=args.num_class,
                         multitask=args.multitask,
                         liu=args.liu,
                         chen=args.chen,
                         flagCBAM=False,
                         crossCBAM=args.crossCBAM)
    elif args.arch == "resnet34":
        from models.resnet50 import resnet34
        model = resnet34(num_classes=args.num_class,
                         multitask=args.multitask,
                         liu=args.liu,
                         chen=args.chen,
                         flagCBAM=False,
                         crossCBAM=args.crossCBAM)
    elif args.arch == "resnet18":
        from models.resnet50 import resnet18
        model = resnet18(num_classes=args.num_class,
                         multitask=args.multitask,
                         liu=args.liu,
                         chen=args.chen,
                         flagCBAM=False,
                         crossCBAM=args.crossCBAM)
    elif args.arch == "densenet161":
        from models.densenet import densenet161
        model = densenet161(num_classes=args.num_class,
                            multitask=args.multitask,
                            cosface=False,
                            liu=args.liu,
                            chen=args.chen)
    elif args.arch == "wired":
        from models.wirednetwork import CNN
        model = CNN(args, num_classes=args.num_class)
    else:
        # NOTE(review): this branch only prints; `model` stays undefined and
        # the uses below would raise NameError for an unrecognized arch.
        print("no backbone model")

    if args.pretrained:
        print("==> Load pretrained model")
        model_dict = model.state_dict()
        pretrain_path = {
            "resnet50": "pretrain/resnet50-19c8e357.pth",
            "resnet34": "pretrain/resnet34-333f7ec4.pth",
            "resnet18": "pretrain/resnet18-5c106cde.pth",
            "densenet161": "pretrain/densenet161-8d451a50.pth",
            "vgg11": "pretrain/vgg11-bbd30ac9.pth"
        }[args.arch]
        pretrained_dict = torch.load(pretrain_path)
        # Keep only weights whose names exist in this model, and drop the
        # ImageNet classifier head (its class count differs from ours).
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        pretrained_dict.pop('classifier.weight', None)
        pretrained_dict.pop('classifier.bias', None)
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    torch.cuda.set_device(args.gpu)
    model = model.cuda(args.gpu)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)
    # optimizer = torch.optim.SGD(model.parameters(), args.base_lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)

    # optionally resume from a checkpoint (full restore, incl. optimizer)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location remaps tensors saved on cuda:4 onto cuda:0.
            checkpoint = torch.load(args.resume,
                                    map_location={'cuda:4': 'cuda:0'})
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            exit(0)

    # ImageNet channel statistics; inputs end up as 224x224 crops.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    size = 224
    # tra = transforms.Compose([
    #             # transforms.Resize(256),
    #             transforms.RandomResizedCrop(size),
    #             transforms.RandomHorizontalFlip(),
    #             transforms.RandomVerticalFlip(),
    #             # transforms.RandomRotation(90),
    #             # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
    #             transforms.ToTensor(),
    #             normalize,
    #         ])
    # tra_test = transforms.Compose([
    #         transforms.Resize(size+32),
    #         transforms.CenterCrop(size),
    #         transforms.ToTensor(),
    #         normalize])

    # Heavy training augmentation: flips, full-range rotation, affine
    # jitter, then a random 224 crop out of the 350-px resize.
    tra = transforms.Compose([
        transforms.Resize(350),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.RandomRotation([-180, 180]),
        transforms.RandomAffine([-180, 180],
                                translate=[0.1, 0.1],
                                scale=[0.7, 1.3]),
        transforms.RandomCrop(224),
        #            transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    # tra = transforms.Compose([
    #     transforms.Resize(350),
    #     transforms.RandomHorizontalFlip(),
    #     transforms.RandomVerticalFlip(),
    #     transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    #     transforms.ToTensor(),
    #     normalize])

    # Deterministic eval pipeline: resize then center-crop.
    tra_test = transforms.Compose([
        transforms.Resize(350),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])
    #
    # from autoaugment import ImageNetPolicy
    # tra =transforms.Compose([
    #      transforms.RandomResizedCrop(224),
    #      transforms.RandomHorizontalFlip(),
    #      ImageNetPolicy(),
    #      transforms.ToTensor(),
    #      normalize])

    # image = PIL.Image.open(path)
    # policy = ImageNetPolicy()
    # transformed = policy(image)

    # Each dataset module exposes a `traindataset` class with the same API.
    if args.dataset == 'amd':
        from datasets.amd_dataset import traindataset
    elif args.dataset == 'pm':
        from datasets.pm_dataset import traindataset
    elif args.dataset == "drdme":
        from datasets.drdme_dataset import traindataset
    elif args.dataset == "missidor":
        from datasets.missidor import traindataset
    else:
        print("no dataset")
        exit(0)

    # One-shot evaluation via test-time augmentation, then return.
    if args.evaluate:
        # result = validate(val_loader, model, args)
        result = multi_validate(model, test_times, normalize, traindataset,
                                args)
        print("acc_dr, acc_dme, acc_joint", result)
        return

    val_dataset = traindataset(root=args.data,
                               mode='val',
                               transform=tra_test,
                               num_class=args.num_class,
                               multitask=args.multitask)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    train_dataset = traindataset(root=args.data,
                                 mode='train',
                                 transform=tra,
                                 num_class=args.num_class,
                                 multitask=args.multitask)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               worker_init_fn=worker_init_fn)

    # TensorBoard logging; the full arg namespace is stored for provenance.
    writer = SummaryWriter()
    writer.add_text('Text', str(args))

    # from lr_scheduler import LRScheduler
    # lr_scheduler = LRScheduler(optimizer, len(train_loader), args)

    for epoch in range(args.start_epoch, args.epochs):
        is_best = False
        lr = adjust_learning_rate(optimizer, epoch, args)
        writer.add_scalar("lr", lr, epoch)
        # train for one epoch
        loss_train = train(train_loader, model, criterion, optimizer, args)
        writer.add_scalar('Train loss', loss_train, epoch)

        # evaluate on validation set every 20 epochs; track joint accuracy
        if epoch % 20 == 0:
            acc_dr, acc_dme, joint_acc = validate(val_loader, model, args)
            writer.add_scalar("Val acc_dr", acc_dr, epoch)
            writer.add_scalar("Val acc_dme", acc_dme, epoch)
            writer.add_scalar("Val acc_joint", joint_acc, epoch)
            is_best = joint_acc >= best_acc1
            best_acc1 = max(joint_acc, best_acc1)

        # Checkpoint each new best (per-epoch filename) unless --invalid.
        if not args.invalid:
            if is_best:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_acc1': best_acc1,
                        'optimizer': optimizer.state_dict(),
                    },
                    is_best,
                    filename="checkpoint" + str(epoch) + ".pth.tar",
                    save_dir=args.model_dir)
Example #16
0
def get_model(class_num):
    """Instantiate the backbone chosen by the module-level configuration.

    The network family comes from ``MODEL_TYPE`` and its variant from
    ``MODEL_DEPTH_OR_VERSION``; ``FINETUNE`` decides whether pretrained
    weights are loaded.  The stock classification head is then replaced by
    one emitting ``class_num`` logits (squeezenet keeps its original head).

    Exits the process with status 1 on an unknown type/depth combination.
    """
    if MODEL_TYPE == 'alexnet':
        model = alexnet.alexnet(pretrained=FINETUNE)
    elif MODEL_TYPE == 'vgg':
        vgg_builders = {
            11: vgg.vgg11,
            13: vgg.vgg13,
            16: vgg.vgg16,
            19: vgg.vgg19,
        }
        if MODEL_DEPTH_OR_VERSION not in vgg_builders:
            print('Error : VGG should have depth of either [11, 13, 16, 19]')
            sys.exit(1)
        model = vgg_builders[MODEL_DEPTH_OR_VERSION](pretrained=FINETUNE)
    elif MODEL_TYPE == 'squeezenet':
        if MODEL_DEPTH_OR_VERSION in (0, 'v0'):
            model = squeezenet.squeezenet1_0(pretrained=FINETUNE)
        elif MODEL_DEPTH_OR_VERSION in (1, 'v1'):
            model = squeezenet.squeezenet1_1(pretrained=FINETUNE)
        else:
            print('Error : Squeezenet should have version of either [0, 1]')
            sys.exit(1)
    elif MODEL_TYPE == 'resnet':
        resnet_builders = {
            18: resnet.resnet18,
            34: resnet.resnet34,
            50: resnet.resnet50,
            101: resnet.resnet101,
            152: resnet.resnet152,
        }
        if MODEL_DEPTH_OR_VERSION not in resnet_builders:
            print(
                'Error : Resnet should have depth of either [18, 34, 50, 101, 152]'
            )
            sys.exit(1)
        model = resnet_builders[MODEL_DEPTH_OR_VERSION](pretrained=FINETUNE)
    elif MODEL_TYPE == 'densenet':
        densenet_builders = {
            121: densenet.densenet121,
            169: densenet.densenet169,
            161: densenet.densenet161,
            201: densenet.densenet201,
        }
        if MODEL_DEPTH_OR_VERSION not in densenet_builders:
            print(
                'Error : Densenet should have depth of either [121, 169, 161, 201]'
            )
            sys.exit(1)
        model = densenet_builders[MODEL_DEPTH_OR_VERSION](pretrained=FINETUNE)
    elif MODEL_TYPE == 'inception':
        if MODEL_DEPTH_OR_VERSION in (3, 'v3'):
            model = inception.inception_v3(pretrained=FINETUNE)
        else:
            print('Error : Inception should have version of either [3, ]')
            sys.exit(1)
    else:
        print(
            'Error : Network should be either [alexnet / squeezenet / vgg / resnet / densenet / inception]'
        )
        sys.exit(1)

    # Swap the stock classification head for a class_num-way one; the
    # attribute holding the head differs per family.
    if MODEL_TYPE in ('alexnet', 'vgg'):
        in_features = model.classifier[6].in_features
        head_layers = list(model.classifier.children())[:-1]
        head_layers.append(nn.Linear(in_features, class_num))
        model.classifier = nn.Sequential(*head_layers)
    elif MODEL_TYPE in ('resnet', 'inception'):
        model.fc = nn.Linear(model.fc.in_features, class_num)
    elif MODEL_TYPE == 'densenet':
        model.classifier = nn.Linear(model.classifier.in_features, class_num)

    return model
Example #17
0
# Swap the last two axes of each sample so the layout matches the model
# input.  NOTE(review): assumes frames is a 3-D array — confirm with the
# producer of `frames`.
frames = np.transpose(frames, [0, 2, 1])
# Random 50/50 train/test split (no fixed seed, so the split varies per run).
# NOTE(review): frames[p] / act_labels[p] are each re-indexed twice; binding
# the shuffled arrays once would avoid the duplicate fancy-index copies.
p = np.random.permutation(len(frames))
n_train = int(len(frames) * 0.5)
x_train, x_test = frames[p][:n_train], frames[p][n_train:]
y_train, y_test = act_labels[p][:n_train], act_labels[p][n_train:]
train_ds = torch.utils.data.TensorDataset(torch.tensor(x_train, dtype=torch.float), torch.tensor(y_train, dtype=torch.long))
test_ds = torch.utils.data.TensorDataset(torch.tensor(x_test, dtype=torch.float), torch.tensor(y_test, dtype=torch.long))
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader= torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)
print('in_shape: {}'.format(x_train.shape[1:]))
print('n_train: {}'.format(n_train))
print('n_test: {}'.format(len(frames)-n_train))

# Models
# model = SimpleCNN(x_train.shape[1:], n_classes=5).to(device)
# 5-way classifier over 3-channel inputs.
model = vgg11(in_channels=3, num_classes=5).to(device)

# Training
# Class weighting is currently disabled (weights=None → unweighted loss).
# weights = torch.tensor([np.sum(act_labels == i) for i in range(5)], dtype=torch.float).to(device)
# weights /= weights.max(0)[0]
# weights = 1 / weights
weights = None
print('weights: {}'.format(weights))
criterion = nn.CrossEntropyLoss(weight=weights, reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

# Per-epoch metric histories, filled by the training loop below.
train_loss_list = []
train_acc_list = []
test_loss_list = []
test_acc_list = []