def run(num_epochs=5, device="cuda"):
    """Train a MobileNet classifier on the DB-backed symbol dataset and persist it.

    Builds a class-balanced, augmented dataset from the database, trains with a
    frozen backbone that is unfrozen mid-training via a solver callback,
    evaluates on a held-out split, writes the weights to ``test_augment.pth``,
    and stores ONNX + PyTorch bytes in a new ``ClassificationModel`` row.

    Args:
        num_epochs: number of training epochs passed to the solver.
        device: torch device string to train on (default "cuda").
    """
    criterion = nn.CrossEntropyLoss()
    # approximate=True selects the cheaper augmentation mode; the solver
    # callback below switches it off once the backbone is unfrozen.
    augmenter = Augmenter(approximate=True)
    full_dataset = DBDataset(TrainImage, MathSymbol, get_data(augmenter), get_label, get_class_name)
    full_dataset = BalancedDS(full_dataset)
    test_train_split = 0.9
    train_size = int(test_train_split * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
    dataloaders = {
        "train": DataLoader(train_dataset, batch_size=32, num_workers=1),
        "test": DataLoader(test_dataset, batch_size=16, shuffle=True, num_workers=1)
    }
    dataset_sizes = {
        "train": len(train_dataset),
        "test": len(test_dataset)
    }
    print(dataloaders)
    print(dataset_sizes)
    model = MobileNet(features=len(full_dataset.classes), pretrained=True)
    # Train only the classifier head first; the backbone is unfrozen later.
    model.freeze()
    model = model.to(device)

    def unfreeze(x, y):
        # Solver callback (arguments unused here): enable full-network training
        # and switch the augmenter to exact (non-approximate) transforms.
        model.unfreeze()
        augmenter.approximate = False
    solver = Solver(criterion, dataloaders, model, cb=unfreeze)
    model, accuracy = solver.train(device=device, num_epochs=num_epochs)
    eval_model(model, dataloaders["test"], device, len(full_dataset.classes))
    model = model.to('cpu')
    torch.save(model.state_dict(), "test_augment.pth")
    # NOTE(review): "estimate_variane" looks like a typo for "estimate_variance",
    # but the same spelling is used elsewhere in this codebase (see the
    # MobileNet.from_file call sites) — confirm before renaming.
    model.estimate_variane = True
    byteArr = model.to_onnx()
    torchByteArr = io.BytesIO()
    torch.save(model.state_dict(), torchByteArr)
    model_entity = ClassificationModel(None, model=byteArr.getvalue(), timestamp=timezone.now(), pytorch=torchByteArr.getvalue(), accuracy=accuracy)
    model_entity.save()
def train_classifier(train_batch_size=16, test_batch_size=4, device="cpu", num_epochs=5):
    """Retrain only the classifier head on top of the latest stored backbone.

    Trains a small ``LinearModel`` on precomputed features, grafts it onto the
    most recent stored network via ``set_classifier``, then saves the combined
    model (ONNX + PyTorch bytes) as a new ``ClassificationModel`` row.

    Args:
        train_batch_size: batch size for the training dataloader.
        test_batch_size: batch size for the test dataloader.
        device: torch device string to train on.
        num_epochs: number of epochs passed to the solver.
    """
    loss_fn = nn.CrossEntropyLoss()
    dataloaders, full_dataset = setup_db_dl(train_batch_size, test_batch_size, get_data)
    print(dataloaders)
    base_model = ClassificationModel.get_latest().to_pytorch()
    head = LinearModel(full_dataset.get_input_shape()[0], full_dataset.num_classes).to(device)
    solver = Solver(loss_fn, dataloaders, head)
    head, accuracy = solver.train(device=device, num_epochs=num_epochs, step_size=2)
    eval_model(head, dataloaders["test"], device, len(full_dataset.classes))
    head = head.to('cpu')
    # Attach the freshly trained head to the stored backbone.
    base_model.set_classifier(head.classifier)
    base_model = base_model.eval()
    # Free the dataset/dataloader memory before the ONNX export allocates.
    del dataloaders
    del full_dataset
    gc.collect()
    onnx_buf = io.BytesIO()
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(base_model, dummy_input, onnx_buf)
    state_buf = io.BytesIO()
    torch.save(base_model.state_dict(), state_buf)
    entity = ClassificationModel(
        None,
        model=onnx_buf.getvalue(),
        timestamp=timezone.now(),
        pytorch=state_buf.getvalue(),
        accuracy=accuracy,
    )
    entity.save()
def run():
    """Load the saved MobileNet checkpoint and register it as the latest model."""
    checkpoint = Path('test_augment.pth')
    raw_weights = checkpoint.read_bytes()
    net = MobileNet.from_file('test_augment.pth')
    net.eval()
    onnx_bytes = net.to_onnx()
    entity = ClassificationModel(
        None,
        model=onnx_bytes.getvalue(),
        pytorch=raw_weights,
        timestamp=timezone.now(),
        accuracy=0.99,
    )
    entity.save()
def run():
    """Evaluate the latest stored classifier on the res/test image folder.

    Prints each misclassified (actual, predicted) pair and a final
    ``Correct: k/n`` summary. Class names come from MathSymbol rows ordered by
    timestamp — assumed to match the class ordering used at training time
    (TODO confirm against the training pipeline).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    names = [ms.name for ms in MathSymbol.objects.all().order_by('timestamp')]
    print(names)
    data_dir = 'res/test'
    full_dataset = datasets.ImageFolder(data_dir, mm.preprocess)
    dataloader = DataLoader(full_dataset, batch_size=1, shuffle=True, num_workers=4)
    model = ClassificationModel.get_latest().to_pytorch()
    model.eval()
    model = model.to(device)
    correct = 0
    # Fix: run inference under no_grad — the original built autograd graphs
    # for every batch, wasting memory and time during pure evaluation.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            # batch_size == 1, so labels/preds each index a single item
            actual_name, pred_name = full_dataset.classes[labels], names[preds]
            if actual_name == pred_name:
                correct += 1
            else:
                print(actual_name, pred_name)
    # len(dataloader) == dataset size because batch_size is 1
    print(f'Correct: {correct}/{len(dataloader)}')
def run():
    """Export the local MobileNet checkpoint (variance-estimating mode) to the DB."""
    file_name = "test_augment.pth"
    # NOTE: "estimate_variane" spelling matches the project-wide attribute name.
    net = MobileNet.from_file(file_name, test_time_dropout=False, estimate_variane=True)
    onnx_buf = net.to_onnx()
    state_buf = io.BytesIO()
    torch.save(net.state_dict(), state_buf)
    entity = ClassificationModel(
        None,
        model=onnx_buf.getvalue(),
        timestamp=timezone.now(),
        pytorch=state_buf.getvalue(),
        accuracy=0.99,
    )
    entity.save()
def create(self, request):
    """Persist a client-uploaded classification model (authenticated users only).

    Expects base64-encoded 'pytorch' and 'onnx' fields in the request body.
    """
    # NOTE(review): this only rejects anonymous users, while the message says
    # "root" — confirm whether a superuser check was intended.
    if request.user.id is None:
        raise PermissionDenied(
            {"message": "Can only create model as root"})
    # Payload fields arrive base64-encoded; decode them to raw bytes.
    pytorch_bytes = base64.decodebytes(request.data['pytorch'].encode())
    onnx_bytes = base64.decodebytes(request.data['onnx'].encode())
    entity = ClassificationModel(
        None,
        model=onnx_bytes,
        pytorch=pytorch_bytes,
        timestamp=timezone.now(),
        accuracy=0.99,
    )
    entity.save()
    return Response('Ok')
def run(num_epochs=5, device="cpu"):
    """Train a MobileNet from scratch on the DB dataset and store the result.

    Args:
        num_epochs: number of training epochs passed to the solver.
        device: torch device string to train on (default "cpu").
    """
    dataloaders, full_dataset = setup_db_dl()
    print(dataloaders)
    loss_fn = nn.CrossEntropyLoss()
    net = mm.MobileNet(features=len(full_dataset.classes), pretrained=False).to(device)
    solver = Solver(loss_fn, dataloaders, net)
    net, accuracy = solver.train(device=device, num_epochs=num_epochs)
    onnx_buf = net.to_onnx()
    state_buf = io.BytesIO()
    torch.save(net.state_dict(), state_buf)
    entity = ClassificationModel(
        None,
        model=onnx_buf.getvalue(),
        timestamp=timezone.now(),
        pytorch=state_buf.getvalue(),
        accuracy=accuracy,
    )
    entity.save()
def update_features(self, train_image, img):
    """Compute pooled backbone features for ``img`` and store them on ``train_image``.

    Args:
        train_image: DB row whose ``features`` blob is overwritten and saved.
        img: single-channel input image; replicated to 3 channels before the net.
    """
    with torch.no_grad():
        net = ClassificationModel.get_latest().to_pytorch()
        tensor = mm.preprocess(img)
        tensor = tensor.repeat((3, 1, 1))  # grayscale -> 3-channel input
        tensor = tensor.unsqueeze(0)       # add batch dimension
        # Global-average-pool the spatial dims of the feature map
        pooled = net.features(tensor).mean([2, 3])
        buf = io.BytesIO()
        torch.save(pooled, buf)
        train_image.features = buf.getvalue()
        train_image.save()
def run():
    """Smoke-test the latest stored ONNX model against the res/test image folder."""
    session = ClassificationModel.get_latest().to_onnx()
    data_dir = 'res/test'
    dataset = datasets.ImageFolder(data_dir, mm.preprocess)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
    for inputs, labels in loader:
        raw_outputs = session.run(None, {'input.1': inputs.numpy()})
        _, preds = torch.max(torch.from_numpy(np.array(raw_outputs)), 2)
        print(preds, dataset.classes[labels], raw_outputs)
def test_most_recent_looks_at_timestamp(self):
    """The latest-model endpoint returns 304 when the query timestamp matches the newest model."""
    latest = ClassificationModel.get_latest()
    formatted = latest.timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")
    formatted = f'{formatted}Z'
    response = self.client.get(f'/api/model/latest/?timestamp={formatted}')
    # Fix: assertEquals is a deprecated alias, removed in Python 3.12's unittest.
    self.assertEqual(response.status_code, status.HTTP_304_NOT_MODIFIED)