예제 #1
0
def run(num_epochs=5, device="cuda"):
    """Train an augmented MobileNet on the DB-backed symbol dataset and persist it.

    Builds a balanced dataset from the database, trains a pretrained MobileNet
    with a frozen backbone first (the solver callback unfreezes it and turns off
    approximate augmentation later), evaluates on the held-out split, then saves
    the model to disk, exports it to ONNX, and stores a new ClassificationModel
    row containing both the ONNX bytes and the pytorch state dict.

    Args:
        num_epochs: Number of epochs handed to the solver.
        device: Torch device string to train on (default "cuda").
    """
    criterion = nn.CrossEntropyLoss()

    # Approximate (cheap) augmentation during the frozen-backbone phase;
    # the unfreeze callback below switches it to exact augmentation.
    augmenter = Augmenter(approximate=True)

    full_dataset = DBDataset(TrainImage, MathSymbol, get_data(augmenter),
                             get_label, get_class_name)

    # Wrap to balance class frequencies.
    full_dataset = BalancedDS(full_dataset)

    # 90/10 train/test split.
    test_train_split = 0.9
    train_size = int(test_train_split * len(full_dataset))
    test_size = len(full_dataset) - train_size

    train_dataset, test_dataset = random_split(full_dataset,
                                               [train_size, test_size])

    dataloaders = {
        "train": DataLoader(train_dataset, batch_size=32, num_workers=1),
        "test":  DataLoader(test_dataset, batch_size=16, shuffle=True,
                            num_workers=1)
    }
    dataset_sizes = {
        "train": len(train_dataset),
        "test": len(test_dataset)
    }

    print(dataloaders)
    print(dataset_sizes)

    model = MobileNet(features=len(full_dataset.classes), pretrained=True)
    # Start with the backbone frozen; only the head trains initially.
    model.freeze()
    model = model.to(device)

    # Solver callback: mid-training, unfreeze the whole network and switch
    # the augmenter to exact (non-approximate) transforms.
    def unfreeze(x, y):
        model.unfreeze()
        augmenter.approximate = False

    solver = Solver(criterion, dataloaders, model, cb=unfreeze)
    model, accuracy = solver.train(device=device,
                                   num_epochs=num_epochs)

    eval_model(model, dataloaders["test"], device, len(full_dataset.classes))

    # Save weights to disk before toggling variance estimation, so the
    # on-disk checkpoint reflects the plain trained model.
    model = model.to('cpu')
    torch.save(model.state_dict(), "test_augment.pth")

    # NOTE(review): "estimate_variane" looks misspelled ("variance"), but the
    # same spelling is used as a MobileNet.from_file keyword elsewhere in this
    # project — do not rename without checking the model class definition.
    model.estimate_variane = True

    byteArr = model.to_onnx()

    torchByteArr = io.BytesIO()
    torch.save(model.state_dict(), torchByteArr)

    # Persist ONNX + state dict + accuracy as a new model row.
    model_entity = ClassificationModel(None, model=byteArr.getvalue(),
                                       timestamp=timezone.now(),
                                       pytorch=torchByteArr.getvalue(),
                                       accuracy=accuracy)
    model_entity.save()
예제 #2
0
def train_classifier(train_batch_size=16,
                     test_batch_size=4,
                     device="cpu",
                     num_epochs=5):
    """Fit a linear classifier head and graft it onto the latest stored model.

    A fresh LinearModel is trained on the database-backed dataset, its
    classifier layer replaces the one on the most recently persisted
    ClassificationModel, and the combined network is exported to ONNX and
    saved (together with its state dict and accuracy) as a new model entity.

    Args:
        train_batch_size: Batch size for the training loader.
        test_batch_size: Batch size for the test loader.
        device: Torch device string to train on.
        num_epochs: Number of epochs handed to the solver.
    """
    loss_fn = nn.CrossEntropyLoss()

    loaders, dataset = setup_db_dl(train_batch_size, test_batch_size, get_data)
    print(loaders)

    # The previously persisted full network; its classifier gets replaced.
    base_model = ClassificationModel.get_latest().to_pytorch()

    head = LinearModel(dataset.get_input_shape()[0], dataset.num_classes)
    head = head.to(device)

    solver = Solver(loss_fn, loaders, head)
    head, accuracy = solver.train(device=device,
                                  num_epochs=num_epochs,
                                  step_size=2)

    eval_model(head, loaders["test"], device, len(dataset.classes))

    # Move the trained head back to CPU and splice it into the stored model.
    head = head.to('cpu')
    base_model.set_classifier(head.classifier)
    base_model = base_model.eval()

    # Release dataset/loader memory before the (large) ONNX export.
    del loaders
    del dataset
    gc.collect()

    onnx_buf = io.BytesIO()
    torch.onnx.export(base_model, torch.randn(1, 3, 224, 224), onnx_buf)

    state_buf = io.BytesIO()
    torch.save(base_model.state_dict(), state_buf)

    entity = ClassificationModel(None, model=onnx_buf.getvalue(),
                                 timestamp=timezone.now(),
                                 pytorch=state_buf.getvalue(),
                                 accuracy=accuracy)
    entity.save()
예제 #3
0
def run():
    """Register the locally saved MobileNet weights as a new stored model.

    Reads the raw checkpoint bytes from disk, rebuilds the network from the
    same file, exports it to ONNX, and stores both representations as a new
    ClassificationModel entry with a fixed accuracy of 0.99.
    """
    weights_path = 'test_augment.pth'

    # Raw checkpoint bytes, persisted verbatim in the "pytorch" column.
    raw_state = Path(weights_path).read_bytes()

    net = MobileNet.from_file(weights_path)
    net.eval()

    onnx_bytes = net.to_onnx()

    entity = ClassificationModel(None,
                                 model=onnx_bytes.getvalue(),
                                 pytorch=raw_state,
                                 timestamp=timezone.now(),
                                 accuracy=0.99)
    entity.save()
예제 #4
0
File: test.py  Project: Hoff97/detext
def run():
    """Evaluate the latest stored classifier against the images in res/test.

    Feeds every image through the model one at a time, compares the predicted
    symbol name (class order taken from the DB, sorted by timestamp) with the
    image's folder name, prints each mismatch, and finishes with a
    correct/total summary.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Class names in DB timestamp order — assumed to line up with the
    # model's output indices.
    names = [ms.name for ms in MathSymbol.objects.all().order_by('timestamp')]
    print(names)

    dataset = datasets.ImageFolder('res/test', mm.preprocess)
    loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=True,
                        num_workers=4)

    net = ClassificationModel.get_latest().to_pytorch()
    net.eval()
    net = net.to(device)

    hits = 0

    for batch in loader:
        images, labels = batch
        images = images.to(device)

        scores = net(images)
        _, guess = torch.max(scores, 1)

        expected = dataset.classes[labels]
        predicted = names[guess]
        if expected == predicted:
            hits += 1
        else:
            print(expected, predicted)

    # batch_size=1, so len(loader) equals the number of images.
    print(f'Correct: {hits}/{len(loader)}')
예제 #5
0
def run():
    """Re-export the saved augmented MobileNet with variance estimation on.

    Loads weights from test_augment.pth (test-time dropout disabled,
    variance estimation enabled), converts the network to ONNX, and saves a
    new ClassificationModel row with a fixed accuracy of 0.99.
    """
    weights = "test_augment.pth"

    # NOTE: "estimate_variane" matches the project's spelling of this keyword.
    net = MobileNet.from_file(weights,
                              test_time_dropout=False,
                              estimate_variane=True)

    onnx_buf = net.to_onnx()

    state_buf = io.BytesIO()
    torch.save(net.state_dict(), state_buf)

    entity = ClassificationModel(None,
                                 model=onnx_buf.getvalue(),
                                 timestamp=timezone.now(),
                                 pytorch=state_buf.getvalue(),
                                 accuracy=0.99)
    entity.save()
예제 #6
0
    def create(self, request):
        """Store a client-uploaded model (base64-encoded ONNX + state dict).

        Expects request.data['onnx'] and request.data['pytorch'] as base64
        strings; decodes them and saves a new ClassificationModel row with a
        hard-coded accuracy of 0.99.

        Raises:
            PermissionDenied: for anonymous users (user id is None).

        NOTE(review): the error message says "root", but the check only
        rejects anonymous users — any authenticated user passes. Confirm
        whether this should test for a superuser/staff flag instead.
        """
        if request.user.id is None:
            raise PermissionDenied(
                {"message": "Can only create model as root"})

        # Payloads arrive base64-encoded; decode back to raw bytes.
        pytorch = request.data['pytorch']
        pytorch = base64.decodebytes(pytorch.encode())

        onnx = request.data['onnx']
        onnx = base64.decodebytes(onnx.encode())

        # Accuracy is hard-coded to 0.99 for uploaded models.
        model_instance = ClassificationModel(None,
                                             model=onnx,
                                             pytorch=pytorch,
                                             timestamp=timezone.now(),
                                             accuracy=0.99)
        model_instance.save()

        return Response('Ok')
예제 #7
0
def run(num_epochs=5, device="cpu"):
    """Train a MobileNet from scratch on the DB dataset and persist it.

    Uses a non-pretrained backbone and the default database loaders, trains
    for the given number of epochs, exports the result to ONNX, and stores it
    (with state dict and accuracy) as a new ClassificationModel row.

    Args:
        num_epochs: Number of epochs handed to the solver.
        device: Torch device string to train on (default "cpu").
    """
    loss_fn = nn.CrossEntropyLoss()

    loaders, dataset = setup_db_dl()
    print(loaders)

    net = mm.MobileNet(features=len(dataset.classes), pretrained=False)
    net = net.to(device)

    solver = Solver(loss_fn, loaders, net)
    net, accuracy = solver.train(device=device, num_epochs=num_epochs)

    onnx_buf = net.to_onnx()

    state_buf = io.BytesIO()
    torch.save(net.state_dict(), state_buf)

    entity = ClassificationModel(None,
                                 model=onnx_buf.getvalue(),
                                 timestamp=timezone.now(),
                                 pytorch=state_buf.getvalue(),
                                 accuracy=accuracy)
    entity.save()
예제 #8
0
    def update_features(self, train_image, img):
        """Compute and persist backbone features for one training image.

        Preprocesses the image, replicates it across channels, batches it,
        runs it through the latest stored model's feature extractor, averages
        over the spatial dimensions, and writes the serialized tensor into
        train_image.features.
        """
        with torch.no_grad():
            net = ClassificationModel.get_latest().to_pytorch()

            tensor = mm.preprocess(img)
            # Replicate along the channel axis (assumes a single-channel
            # preprocess output — TODO confirm), then add a batch dimension.
            tensor = tensor.repeat((3, 1, 1))
            tensor = tensor.reshape((1, *tensor.shape))

            # Spatial mean over H and W of the backbone feature map.
            feats = net.features(tensor).mean([2, 3])

            buf = io.BytesIO()
            torch.save(feats, buf)

            train_image.features = buf.getvalue()
            train_image.save()
예제 #9
0
def run():
    """Sanity-check the latest model's ONNX export on the res/test images.

    Runs each image through an onnxruntime session one at a time and prints
    the predicted class index, the ground-truth folder name, and the raw
    network outputs.
    """
    session = ClassificationModel.get_latest().to_onnx()

    dataset = datasets.ImageFolder('res/test', mm.preprocess)
    loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=True,
                        num_workers=4)

    for batch in loader:
        images, labels = batch
        images = images.numpy()

        # 'input.1' is the auto-generated name of the exported graph input.
        raw = session.run(None, {'input.1': images})
        _, guess = torch.max(torch.from_numpy(np.array(raw)), 2)
        print(guess, dataset.classes[labels], raw)
예제 #10
0
 def test_most_recent_looks_at_timestamp(self):
     """Requesting /latest/ with the current timestamp must yield 304.

     Formats the newest model's timestamp as ISO-8601 with microseconds and
     a literal 'Z' suffix, then checks that the endpoint reports Not
     Modified for a client that is already up to date.
     """
     latest = ClassificationModel.get_latest()
     formatted = latest.timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")
     formatted = f'{formatted}Z'
     response = self.client.get(f'/api/model/latest/?timestamp={formatted}')
     # assertEquals is a deprecated alias (removed in Python 3.12);
     # assertEqual is the supported spelling.
     self.assertEqual(response.status_code, status.HTTP_304_NOT_MODIFIED)