Example #1
def run(num_epochs=5, device="cuda"):
    criterion = nn.CrossEntropyLoss()

    # Augmentation in approximate mode; the unfreeze callback below disables it
    augmenter = Augmenter(approximate=True)

    full_dataset = DBDataset(TrainImage, MathSymbol, get_data(augmenter),
                             get_label, get_class_name)

    full_dataset = BalancedDS(full_dataset)

    # Fraction of the data used for training; the rest becomes the test set
    test_train_split = 0.9
    train_size = int(test_train_split * len(full_dataset))
    test_size = len(full_dataset) - train_size

    train_dataset, test_dataset = random_split(full_dataset,
                                               [train_size, test_size])

    dataloaders = {
        "train": DataLoader(train_dataset, batch_size=32, num_workers=1),
        "test":  DataLoader(test_dataset, batch_size=16, shuffle=True,
                            num_workers=1)
    }
    dataset_sizes = {
        "train": len(train_dataset),
        "test": len(test_dataset)
    }

    print(dataloaders)
    print(dataset_sizes)

    # Pretrained MobileNet with one output per class; kept frozen until the
    # Solver triggers the unfreeze callback
    model = MobileNet(features=len(full_dataset.classes), pretrained=True)
    model.freeze()
    model = model.to(device)

    def unfreeze(x, y):
        # Training callback (its arguments are unused here): unfreeze the model
        # and stop using approximate augmentation
        model.unfreeze()
        augmenter.approximate = False

    solver = Solver(criterion, dataloaders, model, cb=unfreeze)
    model, accuracy = solver.train(device=device,
                                   num_epochs=num_epochs)

    eval_model(model, dataloaders["test"], device, len(full_dataset.classes))

    # Move the trained model to the CPU and save its weights to disk
    model = model.to('cpu')
    torch.save(model.state_dict(), "test_augment.pth")

    model.estimate_variane = True

    byteArr = model.to_onnx()

    torchByteArr = io.BytesIO()
    torch.save(model.state_dict(), torchByteArr)

    # Store the ONNX export, the pickled weights, and the achieved accuracy
    # as a new ClassificationModel entity
    model_entity = ClassificationModel(None, model=byteArr.getvalue(),
                                       timestamp=timezone.now(),
                                       pytorch=torchByteArr.getvalue(),
                                       accuracy=accuracy)
    model_entity.save()
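
Each example stores the ONNX export in the model field of a ClassificationModel entity. As a reference point, here is a minimal sketch of running inference directly from those stored bytes; onnxruntime and the 1x3x224x224 input shape (taken from the dummy input in Example #2) are assumptions, not part of the project code.

import numpy as np
import onnxruntime as ort

def predict_from_onnx_bytes(onnx_bytes, image):
    # InferenceSession accepts the serialized model bytes directly
    session = ort.InferenceSession(onnx_bytes)
    input_name = session.get_inputs()[0].name
    # image: float32 array of shape (1, 3, 224, 224)
    return session.run(None, {input_name: image.astype(np.float32)})[0]
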
Example #2
def train_classifier(train_batch_size=16,
                     test_batch_size=4,
                     device="cpu",
                     num_epochs=5):

    criterion = nn.CrossEntropyLoss()

    dataloaders, full_dataset = setup_db_dl(train_batch_size, test_batch_size,
                                            get_data)

    print(dataloaders)

    # Most recently stored model; only its classifier head is retrained below
    old_model = ClassificationModel.get_latest().to_pytorch()

    n_features = full_dataset.get_input_shape()[0]
    n_classes = full_dataset.num_classes

    model = LinearModel(n_features, n_classes)

    model = model.to(device)

    solver = Solver(criterion, dataloaders, model)
    model, accuracy = solver.train(device=device,
                                   num_epochs=num_epochs,
                                   step_size=2)

    eval_model(model, dataloaders["test"], device, len(full_dataset.classes))

    model = model.to('cpu')
    # Graft the newly trained linear classifier onto the stored model
    old_model.set_classifier(model.classifier)
    old_model = old_model.eval()

    # Release the dataset and dataloaders before exporting
    del dataloaders
    del full_dataset
    gc.collect()

    byteArr = io.BytesIO()
    # Export via tracing with a dummy 1x3x224x224 input batch
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(old_model, dummy_input, byteArr)

    torchByteArr = io.BytesIO()
    torch.save(old_model.state_dict(), torchByteArr)

    model_entity = ClassificationModel(None, model=byteArr.getvalue(),
                                       timestamp=timezone.now(),
                                       pytorch=torchByteArr.getvalue(),
                                       accuracy=accuracy)
    model_entity.save()
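
The pytorch field holds exactly what torch.save wrote into the in-memory buffer. The project reads it back through ClassificationModel.get_latest().to_pytorch() (see the top of this example); under the hood, the stored bytes can be turned back into a state_dict with torch.load on a BytesIO buffer, roughly like this (the entity argument name is illustrative):

import io
import torch

def state_dict_from_entity(entity):
    # entity.pytorch contains the bytes written by torch.save(state_dict, buffer)
    return torch.load(io.BytesIO(entity.pytorch), map_location="cpu")
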
Example #3
def run():
    path = 'test_augment.pth'

    pytorch = Path(path).read_bytes()

    model = MobileNet.from_file(path)
    model.eval()

    byte_arr = model.to_onnx()

    model_instance = ClassificationModel(None,
                                         model=byte_arr.getvalue(),
                                         pytorch=pytorch,
                                         timestamp=timezone.now(),
                                         accuracy=0.99)
    model_instance.save()
Example #4
def run():
    file_name = "test_augment.pth"

    model = MobileNet.from_file(file_name,
                                test_time_dropout=False,
                                estimate_variane=True)

    byteArr = model.to_onnx()

    torchByteArr = io.BytesIO()
    torch.save(model.state_dict(), torchByteArr)

    model_entity = ClassificationModel(None,
                                       model=byteArr.getvalue(),
                                       timestamp=timezone.now(),
                                       pytorch=torchByteArr.getvalue(),
                                       accuracy=0.99)
    model_entity.save()
Example #5
    def create(self, request):
        # Reject requests that are not authenticated
        if request.user.id is None:
            raise PermissionDenied(
                {"message": "Can only create model as root"})

        # Both artifacts arrive base64-encoded as text in the request body
        pytorch = request.data['pytorch']
        pytorch = base64.decodebytes(pytorch.encode())

        onnx = request.data['onnx']
        onnx = base64.decodebytes(onnx.encode())

        model_instance = ClassificationModel(None,
                                             model=onnx,
                                             pytorch=pytorch,
                                             timestamp=timezone.now(),
                                             accuracy=0.99)
        model_instance.save()

        return Response('Ok')
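
The create handler expects both artifacts base64-encoded as text under the keys onnx and pytorch. Below is a minimal client-side sketch for building such a request; the requests library, the endpoint URL, and the authentication handling are assumptions, not part of the project code.

import base64
from pathlib import Path

import requests

def upload_model(onnx_path, pytorch_path, url, auth=None):
    payload = {
        "onnx": base64.encodebytes(Path(onnx_path).read_bytes()).decode(),
        "pytorch": base64.encodebytes(Path(pytorch_path).read_bytes()).decode(),
    }
    # The server decodes each field with base64.decodebytes and stores the bytes
    return requests.post(url, json=payload, auth=auth)
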
Example #6
def run(num_epochs=5, device="cpu"):
    criterion = nn.CrossEntropyLoss()

    dataloaders, full_dataset = setup_db_dl()

    print(dataloaders)

    model = mm.MobileNet(features=len(full_dataset.classes), pretrained=False)
    model = model.to(device)

    solver = Solver(criterion, dataloaders, model)
    model, accuracy = solver.train(device=device, num_epochs=num_epochs)

    byteArr = model.to_onnx()

    torchByteArr = io.BytesIO()
    torch.save(model.state_dict(), torchByteArr)

    model_entity = ClassificationModel(None,
                                       model=byteArr.getvalue(),
                                       timestamp=timezone.now(),
                                       pytorch=torchByteArr.getvalue(),
                                       accuracy=accuracy)
    model_entity.save()