Example No. 1
def loadModel(path):
    state_dict = torch.load(path)['state_dict']
    model = models.ConvClassifier()
    model.load_state_dict(state_dict)
    model.cpu()
    model.eval()

    def infer(images):
        if len(images) == 0:
            return []
        data = torch.zeros((len(images), 3, 32, 32))
        for i, img in enumerate(images):
            img = cv2.resize(img, (32, 32))
            data[i] = torch.from_numpy(
                np.transpose(img.astype(np.float32), axes=[2, 0, 1]) / 255)
        #print(data.shape)
        res = model(data)
        #print(res.shape)
        res = torch.argmax(res, dim=1)
        res = [dataset.index2name(t.item()) for t in res]
        return res

    return infer
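
A minimal usage sketch for the loader above (not part of the original example); the checkpoint path and image file are placeholders, and dataset.index2name is assumed to map class indices to label names:

infer = loadModel('results/model_best_acc.pth')  # hypothetical checkpoint path
images = [cv2.imread('example.png')]             # hypothetical BGR input; resized to 32x32 inside infer
print(infer(images))                             # e.g. ['cat']
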
def main(args):
    # hard coded values
    in_channels = 3  # RGB channels of the original image fed to RotNet
    if args.layer == 1:
        in_features = 96
    else:
        in_features = 192
    rot_classes = 4
    out_classes = 10
    lr_decay_rate = 0.2  # lr is multiplied by decay rate after a milestone epoch is reached
    mult = 1  # the data is expanded to mult times its original size
    ####################

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                             (63.0 / 255, 62.1 / 255, 66.7 / 255))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                             (63.0 / 255, 62.1 / 255, 66.7 / 255))
    ])

    trainset = datasets.CIFAR10(root='results/',
                                train=True,
                                download=True,
                                transform=train_transform)
    testset = datasets.CIFAR10(root='results/',
                               train=False,
                               download=True,
                               transform=test_transform)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=0)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0)

    rot_network = mdl.RotNet(in_channels=in_channels,
                             num_nin_blocks=args.nins,
                             out_classes=rot_classes).to(args.device)
    class_network = mdl.ConvClassifier(in_channels=in_features,
                                       out_classes=out_classes).to(args.device)

    if args.opt == 'adam':
        optimizer = optim.Adam(class_network.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(class_network.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=lr_decay_rate)

    ####################################### Saving information
    results_dict = {}
    # These will store the values for the best-test-accuracy model
    results_dict['train_loss'] = -1
    results_dict['train_acc'] = -1
    results_dict['test_loss'] = -1
    results_dict['test_acc'] = -1
    results_dict['best_acc_epoch'] = -1
    # For storing training history
    results_dict['train_loss_hist'] = []
    results_dict['train_acc_hist'] = []
    results_dict['test_loss_hist'] = []
    results_dict['test_acc_hist'] = []

    # file paths for saving model checkpoints
    checkpoint_path = os.path.join(args.results_dir, 'model.pth')
    checkpoint_path_best_acc = os.path.join(args.results_dir,
                                            'model_best_acc.pth')

    rot_network.eval()
    for param in rot_network.parameters():
        param.requires_grad = False

    #########
    test_acc_max = -math.inf
    loop_start_time = time.time()
    checkpoint = {}
    for epoch in range(args.epochs):
        train(args, rot_network, class_network, train_loader, optimizer, mult,
              scheduler, epoch, in_features)

        train_loss, train_acc = test(args, rot_network, class_network,
                                     train_loader, mult, 'Train', in_features)
        results_dict['train_loss_hist'].append(train_loss)
        results_dict['train_acc_hist'].append(train_acc)

        test_loss, test_acc = test(args, rot_network, class_network,
                                   test_loader, mult, 'Test', in_features)
        results_dict['test_loss_hist'].append(test_loss)
        results_dict['test_acc_hist'].append(test_acc)
        print(
            'Epoch {} finished --------------------------------------------------------------------------'
            .format(epoch + 1))

        checkpoint = {
            'model_state_dict': class_network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc
        }

        if test_acc > test_acc_max:
            test_acc_max = test_acc
            if os.path.isfile(checkpoint_path_best_acc):
                os.remove(checkpoint_path_best_acc)

            torch.save(checkpoint, checkpoint_path_best_acc)

            results_dict['best_acc_epoch'] = epoch + 1
            results_dict['train_loss'] = train_loss
            results_dict['train_acc'] = train_acc
            results_dict['test_loss'] = test_loss
            results_dict['test_acc'] = test_acc

    torch.save(checkpoint, checkpoint_path)

    print('Total time for training loop = ', time.time() - loop_start_time)

    return results_dict
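
The best-accuracy checkpoint written by this loop stores the classifier weights under 'model_state_dict'. A minimal sketch of restoring it for later evaluation, assuming the same mdl.ConvClassifier constructor and the path built from args.results_dir above:

best_path = os.path.join(args.results_dir, 'model_best_acc.pth')
checkpoint = torch.load(best_path, map_location='cpu')
class_network = mdl.ConvClassifier(in_channels=in_features, out_classes=out_classes)
class_network.load_state_dict(checkpoint['model_state_dict'])
class_network.eval()
print('best epoch:', checkpoint['epoch'], 'test acc:', checkpoint['test_acc'])
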
Example No. 3
def main():
    global args, best_prec1
    args = parser.parse_args()

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = models.ConvClassifier()

    if args.cpu:
        model.cpu()
    else:
        model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # normalize = torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # fraction of the data held out for validation
    val_ratio = 0.1

    train_loader = torch.utils.data.DataLoader(dataset.MajData(
        hashCriteria=lambda x: ((x % 23333) / 23333 > val_ratio)),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(dataset.MajData(
        hashCriteria=lambda x: ((x % 23333) / 23333 <= val_ratio)),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if args.cpu:
        criterion = criterion.cpu()
    else:
        criterion = criterion.cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir,
                                  'checkpoint_{}.tar'.format(epoch)))
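
save_checkpoint is called above but not defined in this excerpt. A minimal sketch of such a helper, assuming the common torch.save plus shutil.copyfile pattern (the 'model_best.tar' filename is an assumption):

import os
import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.tar'):
    # Always persist the latest state; keep a separate copy of the best-accuracy model.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename,
                        os.path.join(os.path.dirname(filename), 'model_best.tar'))
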
Example No. 4
    trainData = utils.DataNoNumpy(trainDataPath, trainLabelsPath)
    trainDataGen = DataLoader(trainData, batch_size=args.batchSize, num_workers=args.num_workers, shuffle=False)

    testData = utils.DataNoNumpy(testDataPath, testLabelsPath)
    testDataGen = DataLoader(testData, batch_size=args.batchSize, num_workers=args.num_workers, shuffle=False)
    max_batch = len(trainDataGen)

    if args.archType == 1:
        model = models.ClassifierLayer(trainData.inputDim(), int(args.classNum), hiddenSize=args.hiddenUnits)
    elif args.archType == 2:
        model = models.PoolingClassifier(trainData.inputDim(), int(args.classNum), afterPoolingSize=args.afterPoolingSize, hiddenLayerSize=args.hiddenLayerSize)
    elif args.archType == 3:
        model = models.ConvClassifier(trainData.inputDim(), int(args.classNum), kernel_size=args.kernelSize, stride=args.stride)
    model.to(device)

    if args.resume:
        if device.type == "cpu":
            checkpoint_ = torch.load(args.resume, map_location=device.type)
        else:
            checkpoint_ = torch.load(args.resume, map_location=device.type + ":" + str(device.index))

        best_acc = checkpoint_["best_acc"]
        model.load_state_dict(checkpoint_['state_dict'])
        epoch = checkpoint_['epoch']

    criterion = nn.CrossEntropyLoss()
    # Get the base model outputs, used for knowledge distillation
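
The trailing comment points at knowledge distillation. A hedged sketch of the usual distillation loss that such a setup typically combines with the CrossEntropyLoss above; the function name, temperature, and weighting here are illustrative, not taken from the excerpt:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.9):
    # KL divergence between temperature-softened distributions plus the hard-label CE term.
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction='batchmean') * (T * T)
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1.0 - alpha) * hard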