Example 1
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()


    if args.dataset.startswith('cifar'):
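        # CIFAR per-channel mean/std, rescaled from [0, 255] to [0, 1]
        # (roughly 0.491/0.482/0.447 and 0.247/0.244/0.262)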
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=True, download=True, transform=transform_train),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=False, transform=transform_test),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=True, download=True, transform=transform_train),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=False, transform=transform_test),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'imagenet':
        traindir = os.path.join('/home/data/ILSVRC/train')
        valdir = os.path.join('/home/data/ILSVRC/val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        jittering = utils.ColorJitter(brightness=0.4, contrast=0.4,
                                      saturation=0.4)
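        # AlexNet-style "fancy PCA" lighting noise: eigval/eigvec are the
        # leading principal components of ImageNet RGB pixel values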
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009],
                                          [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                jittering,
                lighting,
                normalize,
            ]))

        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        numberofclass = 1000

    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass, args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass,
                                args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(args.net_type))

    model = torch.nn.DataParallel(model).cuda()
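    # DataParallel replicates the model on every visible GPU and splits each
    # input batch along dimension 0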

    print(model)
    print('the number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=True)


    cudnn.benchmark = True


    for epoch in range(0, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'arch': args.net_type,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
Example 2
def main():
    global args, best_err1, best_err5, global_epoch_confusion, best_loss
    args = parser.parse_args()

    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data',
                                  train=True,
                                  download=True,
                                  transform=transform_train),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data',
                                  train=False,
                                  transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_data = datasets.CIFAR10('../data',
                                          train=True,
                                          download=True,
                                          transform=transform_train)
            print(train_data.targets[:30])
            print(len(train_data))

            # per-class sampling weights: upweight the two classes under repair
            class_weights = [1.0] * 10
            class_weights[args.first] = args.weight
            class_weights[args.second] = args.weight
            print(class_weights)
            weights = [
                class_weights[train_data.targets[i]]
                for i in range(len(train_data))
            ]
            sampler = WeightedRandomSampler(torch.DoubleTensor(weights),
                                            len(train_data))
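            # WeightedRandomSampler draws len(train_data) indices per epoch
            # with replacement, picking index i with probability proportional
            # to weights[i]; DataLoader requires shuffle=False when a sampler
            # is given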
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                sampler=sampler)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data',
                                 train=False,
                                 transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 10

        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass,
                          args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))

    model = torch.nn.DataParallel(model).cuda()

    if os.path.isfile(args.pretrained):
        print("=> loading checkpoint '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'".format(args.pretrained))

    # print(model)
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
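    # reduction='none' keeps per-example losses, presumably so train() can
    # re-weight individual samples before averaging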

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    cudnn.benchmark = True
    #validate(val_loader, model, criterion, 0)

    # for checking pre-trained model accuracy and confusion
    if args.checkmodel:
        global_epoch_confusion.append({})
        get_confusion(val_loader, model, criterion)
        # cat->dog confusion
        log_print(str(args.first) + " -> " + str(args.second))
        log_print(global_epoch_confusion[-1]["confusion"][(args.first,
                                                           args.second)])
        # dog->cat confusion
        log_print(str(args.second) + " -> " + str(args.first))
        log_print(global_epoch_confusion[-1]["confusion"][(args.second,
                                                           args.first)])
        exit()

    for epoch in range(0, args.epochs):
        global_epoch_confusion.append({})
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint

        # only track and checkpoint the best model in the last quarter of training
        if epoch >= args.epochs * 0.75:
            is_best = err1 <= best_err1
            best_err1 = min(err1, best_err1)
            if is_best:
                best_err5 = err5

            print('Current best accuracy (top-1 and 5 error):', best_err1,
                  best_err5)
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.net_type,
                    'state_dict': model.state_dict(),
                    'best_err1': best_err1,
                    'best_err5': best_err5,
                    'optimizer': optimizer.state_dict(),
                }, is_best)

        get_confusion(val_loader, model, criterion, epoch)
        # cat->dog confusion
        log_print(str(args.first) + " -> " + str(args.second))
        log_print(global_epoch_confusion[-1]["confusion"][(args.first,
                                                           args.second)])
        # dog->cat confusion
        log_print(str(args.second) + " -> " + str(args.first))
        log_print(global_epoch_confusion[-1]["confusion"][(args.second,
                                                           args.first)])

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
    directory = "runs/%s/" % (args.expname)
    if not os.path.exists(directory):
        os.makedirs(directory)
    epoch_confusions = 'runs/%s/' % (args.expname) + \
        'epoch_confusion_' + args.expid
    np.save(epoch_confusions, global_epoch_confusion)
    log_print("")
    # output best model accuracy and confusion
    repaired_model = 'runs/%s/' % (args.expname) + 'model_best.pth.tar'
    if os.path.isfile(repaired_model):
        print("=> loading checkpoint '{}'".format(repaired_model))
        checkpoint = torch.load(repaired_model)
        model.load_state_dict(checkpoint['state_dict'])
        get_confusion(val_loader, model, criterion)
        # cat->dog confusion
        log_print(str(args.first) + " -> " + str(args.second))
        log_print(global_epoch_confusion[-1]["confusion"][(args.first,
                                                           args.second)])
        # dog->cat confusion
        log_print(str(args.second) + " -> " + str(args.first))
        log_print(global_epoch_confusion[-1]["confusion"][(args.second,
                                                           args.first)])
Example 3
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()
    global directory
    directory = "runs/%s/" % (args.expname)
    if not os.path.exists(directory):
        os.makedirs(directory)

    log = open(os.path.join(directory, 'log.txt'), 'w')

    train_loader, val_loader, numberofclass = load_data(args)

    if args.pretrained:
        num_class_output = 8142
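        # 8142 presumably matches the iNaturalist-2018 class count of the
        # pretraining checkpoint; the value is set but not used below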

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass,
                          args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))

    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    if args.pretrained:
        if os.path.isfile(args.checkpoint):
            print_log("=> loading checkpoint '{}'".format(args.checkpoint),
                      log)
            checkpoint = torch.load(args.checkpoint)
            model.load_state_dict(checkpoint['state_dict'])
            print_log("=> loaded checkpoint '{}'".format(args.checkpoint), log)
            # the model is not wrapped in DataParallel above, so the classifier
            # head is model.fc rather than model.module.fc
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, numberofclass).cuda()
        else:
            raise Exception("=> no checkpoint found at '{}'".format(
                args.checkpoint))

    # num_ftrs = model.module.fc.in_features
    # model.fc = nn.Linear(num_ftrs, numberofclass)

    print_log("=> network :\n {}".format(model), log)
    # print(num_ftrs)
    print_log(
        'the number of model parameters: {}'.format(
            sum([p.data.nelement() for p in model.parameters()])), log)
    # exit()
    # define loss function (criterion) and optimizer
    cifar_version = args.dataset[5:]

    if args.reweight:
        print_log(
            'using re-weighting using beta value : {}'.format(
                args.beta_reweight), log)
        # img_num_per_cls = [int(line.strip()) for line in open(
        # 'img_num_per_cls.txt', 'r')]
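        # class-balanced re-weighting (Cui et al., CVPR 2019): the effective
        # number of samples for a class with n examples is (1 - beta^n) / (1 - beta);
        # each class is weighted by its inverse, normalized to sum to the
        # number of classes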
        img_num_per_cls = get_img_num_per_cls(cifar_version, args.imb_factor)
        effective_num = 1.0 - np.power(args.beta_reweight, img_num_per_cls)
        weights = (1.0 - args.beta_reweight) / np.array(effective_num)
        weights = weights / np.sum(weights) * int(numberofclass)
        weights = torch.tensor(weights).float()
        criterion = nn.CrossEntropyLoss(weight=weights).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    cudnn.benchmark = True

    for epoch in range(0, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch

        train_loss = train(train_loader, model, criterion, optimizer, epoch,
                           log)

        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch,
                                        log)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print_log(
            'Current best accuracy (top-1 and 5 error): {} {}'.format(
                best_err1, best_err5), log)
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': optimizer.state_dict(),
            }, is_best)

    print_log(
        'Best accuracy (top-1 and 5 error): {} {}'.format(
            best_err1, best_err5), log)
Example 4
def main():
    global args, best_err1, best_err5, global_epoch_confusion, best_loss
    args = parser.parse_args()

    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=True,
                                  download=True, transform=transform_train),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=False,
                                  transform=transform_test),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=True,
                                 download=True, transform=transform_train),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=False,
                                 transform=transform_test),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            numberofclass = 10

            target_train_dataset = datasets.CIFAR10(
                '../data', train=True, download=True, transform=transform_train)
            target_train_dataset = get_dataset_from_specific_classes(
                target_train_dataset, args.first, args.second, args.third)
            target_test_dataset = datasets.CIFAR10(
                '../data', train=False, download=True, transform=transform_test)
            target_test_dataset = get_dataset_from_specific_classes(
                target_test_dataset, args.first, args.second, args.third)
            target_train_loader = torch.utils.data.DataLoader(target_train_dataset, batch_size=args.extra, shuffle=True,
                                                              num_workers=args.workers, pin_memory=True)
            target_val_loader = torch.utils.data.DataLoader(target_test_dataset, batch_size=args.extra, shuffle=True,
                                                            num_workers=args.workers, pin_memory=True)

        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth,
                          numberofclass, args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass,
                                args.bottleneck)
    else:
        raise Exception(
            'unknown network architecture: {}'.format(args.net_type))

    model = torch.nn.DataParallel(model).cuda()

    if os.path.isfile(args.pretrained):
        print("=> loading checkpoint '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'".format(args.pretrained))

    # print(model)
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # replace bn layer
    if args.replace:
        model.to('cpu')
        global glob_bn_count
        global glob_bn_total
        glob_bn_total = 0
        glob_bn_count = 0
        count_bn_layer(model)
        print("total bn layer: " + str(glob_bn_total))
        glob_bn_count = 0
        replace_bn(model)
        print(model)
        model = model.cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
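    # reduction='none' keeps per-example losses, presumably so train() can
    # re-weight individual samples before averaging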

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=True)

    cudnn.benchmark = True
    #validate(val_loader, model, criterion, 0)

    # for checking pre-trained model accuracy and confusion
    if args.checkmodel:
        global_epoch_confusion.append({})
        get_confusion(val_loader, model, criterion)
        confusion_matrix = global_epoch_confusion[-1]["confusion"]
        print(str((args.first, args.second, args.third)) + " triplet: " +
              str(abs(confusion_matrix[(args.first, args.second)] -
                      confusion_matrix[(args.first, args.third)])))
        print(str((args.first, args.second)) + ": " + str(confusion_matrix[(args.first, args.second)]))
        print(str((args.first, args.third)) + ": " + str(confusion_matrix[(args.first, args.third)]))
        exit()

    for epoch in range(0, args.epochs):
        global_epoch_confusion.append({})
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, target_train_loader,
              model, criterion, optimizer, epoch)

        # evaluate on validation set
        err1, err5, val_loss = validate(
            val_loader, target_val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint

        # only track and checkpoint the best model in the last quarter of training
        if epoch >= args.epochs * 0.75:
            is_best = err1 <= best_err1
            best_err1 = min(err1, best_err1)
            if is_best:
                best_err5 = err5

            print('Current best accuracy (top-1 and 5 error):',
                  best_err1, best_err5)
            save_checkpoint({
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': optimizer.state_dict(),
            }, is_best)

        get_confusion(val_loader, model, criterion, epoch)
        confusion_matrix = global_epoch_confusion[-1]["confusion"]
        #print("loss: " + str(global_epoch_confusion[-1]["loss"]))
        first_second = compute_confusion(confusion_matrix, args.first, args.second)
        first_third = compute_confusion(confusion_matrix, args.first, args.third)
        print(str((args.first, args.second, args.third)) + " triplet: " +
              str(compute_bias(confusion_matrix, args.first, args.second, args.third)))
        print(str((args.first, args.second)) + ": " + str(first_second))
        print(str((args.first, args.third)) + ": " + str(first_third))

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
    directory = "runs/%s/" % (args.expname)
    if not os.path.exists(directory):
        os.makedirs(directory)
    epoch_confusions = 'runs/%s/' % (args.expname) + \
        'epoch_confusion_' + args.expid
    np.save(epoch_confusions, global_epoch_confusion)
    log_print("")
    # output best model accuracy and confusion
    repaired_model = 'runs/%s/' % (args.expname) + 'model_best.pth.tar'
    if os.path.isfile(repaired_model):
        print("=> loading checkpoint '{}'".format(repaired_model))
        checkpoint = torch.load(repaired_model)
        model.load_state_dict(checkpoint['state_dict'])
        get_confusion(val_loader, model, criterion)
        confusion_matrix = global_epoch_confusion[-1]["confusion"]
        #print("loss: " + str(global_epoch_confusion[-1]["loss"]))
        first_second = compute_confusion(confusion_matrix, args.first, args.second)
        first_third = compute_confusion(confusion_matrix, args.first, args.third)
        print(str((args.first, args.second, args.third)) + " triplet: " +
              str(compute_bias(confusion_matrix, args.first, args.second, args.third)))
        print(str((args.first, args.second)) + ": " + str(first_second))
        print(str((args.first, args.third)) + ": " + str(first_third))
Example 5
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    img = Image.open(args.img_path).convert('RGB')

    if args.dataset.startswith('cifar'):
        size = (32, 32)
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_inference = transforms.Compose(
            [transforms.Resize(size),
             transforms.ToTensor(), normalize])

        img = transform_inference(img).unsqueeze(0)
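        # unsqueeze(0) adds a batch dimension: (C, H, W) -> (1, C, H, W)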

        if args.dataset == 'cifar100':
            numberofclass = 100
        elif args.dataset == 'cifar10':
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'imagenet':

        valdir = os.path.join('/home/data/ILSVRC/val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        numberofclass = 1000

    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass,
                          args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))

    model = torch.nn.DataParallel(model).cuda()

    if os.path.isfile(args.pretrained):
        print("=> loading checkpoint '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'".format(args.pretrained))
    else:
        raise Exception("=> no checkpoint found at '{}'".format(
            args.pretrained))

    model.module.fc = Identity()
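    # with the final fully connected layer replaced by Identity (defined
    # elsewhere in this repo), a forward pass returns penultimate-layer
    # features instead of class logits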

    cudnn.benchmark = True

    print(model)
Example 6
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if args.dataset == 'cifar100':
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data',
                                  download=True,
                                  train=False,
                                  transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data',
                                 download=True,
                                 train=False,
                                 transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'imagenet':

        valdir = os.path.join('/home/data/ILSVRC/val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        numberofclass = 1000

    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass,
                          args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))

    model = torch.nn.DataParallel(model).cuda()

    if os.path.isfile(args.pretrained):
        print("=> loading checkpoint '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'".format(args.pretrained))
    else:
        raise Exception("=> no checkpoint found at '{}'".format(
            args.pretrained))

    print(model)
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    # evaluate on validation set
    err1, err5, val_loss = validate(val_loader, model, criterion)

    print('Accuracy (top-1 and 5 error):', err1, err5)
Example 7
def main():

    global args, best_err1, best_err5

    args = parser.parse_args()

    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=True, download=True, transform=transform_train),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=False, transform=transform_test),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=True, download=True, transform=transform_train),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=False, transform=transform_test),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass, args.bottleneck)
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(args.net_type))


    model = torch.nn.DataParallel(model).cuda()

    #print(model)
    #print('the number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)

    cudnn.benchmark = True

    for epoch in range(0, args.epochs):
        
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'arch': args.net_type,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)