class PNNM(OWWidget): name = "Pytorch CNN" description = "" # icon = "icons/robot.svg" want_main_area = True class Inputs: data = Input('Data', ImageDataBunch, default=True) def __init__(self): super().__init__() self.learn = None # train_button = gui.button(self.controlArea, self, "开始训练", callback=self.train) self.label = gui.label(self.mainArea, self, "模型结构") #: The current evaluating task (if any) self._task = None # type: Optional[Task] #: An executor we use to submit learner evaluations into a thread pool self._executor = ThreadExecutor() # Device configuration self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Hyper parameters num_epochs = 5 num_classes = 10 batch_size = 100 learning_rate = 0.001 dir_path = Path(__file__).resolve() data_path = f'{dir_path.parent.parent.parent}/datasets/' # MNIST dataset self.train_dataset = torchvision.datasets.MNIST(root=data_path, train=True, transform=transforms.ToTensor(), download=False) self.test_dataset = torchvision.datasets.MNIST(root=data_path, train=False, transform=transforms.ToTensor()) # Data loader self.train_loader = torch.utils.data.DataLoader(dataset=self.train_dataset, batch_size=batch_size, shuffle=False) self.test_loader = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=batch_size, shuffle=False) # self.model = ConvNet(num_classes).to(self.device) self.model = nn.Sequential( self.conv(1, 8), # 14 nn.BatchNorm2d(8), nn.ReLU(), self.conv(8, 16), # 7 nn.BatchNorm2d(16), nn.ReLU(), self.conv(16, 32), # 4 nn.BatchNorm2d(32), nn.ReLU(), self.conv(32, 16), # 2 nn.BatchNorm2d(16), nn.ReLU(), self.conv(16, 10), # 1 nn.BatchNorm2d(10), Flatten() # remove (1,1) grid ).to(self.device) # Loss and optimizer self.criterion = nn.CrossEntropyLoss() self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) def handleNewSignals(self): self._update() def _update(self): if self._task is not None: # First make sure any pending tasks are cancelled. self.cancel() assert self._task is None if self.data is None: return # collect all learners for which results have not yet been computed if not self.learn: return # setup the task state self._task = task = Task() # The learning_curve[_with_test_data] also takes a callback function # to report the progress. We instrument this callback to both invoke # the appropriate slots on this widget for reporting the progress # (in a thread safe manner) and to implement cooperative cancellation. set_progress = methodinvoke(self, "setProgressValue", (float,)) def callback(finished): # check if the task has been cancelled and raise an exception # from within. This 'strategy' can only be used with code that # properly cleans up after itself in the case of an exception # (does not leave any global locks, opened file descriptors, ...) if task.cancelled: raise KeyboardInterrupt() set_progress(finished * 100) self.progressBarInit() # Submit the evaluation function to the executor and fill in the # task with the resultant Future. # task.future = self._executor.submit(self.learn.fit_one_cycle(1)) fit_model = partial(train_model, self.model, 5, self.train_loader, self.test_loader, self.device, self.criterion, self.optimizer, callback=callback) task.future = self._executor.submit(fit_model) # Setup the FutureWatcher to notify us of completion task.watcher = FutureWatcher(task.future) # by using FutureWatcher we ensure `_task_finished` slot will be # called from the main GUI thread by the Qt's event loop task.watcher.done.connect(self._task_finished) @pyqtSlot(float) def setProgressValue(self, value): assert self.thread() is QThread.currentThread() self.progressBarSet(value) @pyqtSlot(concurrent.futures.Future) def _task_finished(self, f): """ Parameters ---------- f : Future The future instance holding the result of learner evaluation. """ assert self.thread() is QThread.currentThread() assert self._task is not None assert self._task.future is f assert f.done() self._task = None self.progressBarFinished() self.model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance) with torch.no_grad(): correct = 0 total = 0 for images, labels in self.test_loader: images = images.to(self.device) labels = labels.to(self.device) outputs = self.model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) # try: # result = f.result() # type: List[Results] # except Exception as ex: # # Log the exception with a traceback # log = logging.getLogger() # log.exception(__name__, exc_info=True) # self.error("Exception occurred during evaluation: {!r}".format(ex)) # # clear all results # self.result= None # else: print(self.learn.validate()) # ... and update self.results def cancel(self): """ Cancel the current task (if any). """ if self._task is not None: self._task.cancel() assert self._task.future.done() # disconnect the `_task_finished` slot self._task.watcher.done.disconnect(self._task_finished) self._task = None def onDeleteWidget(self): self.cancel() super().onDeleteWidget() def conv(self, ni, nf): return nn.Conv2d(ni, nf, kernel_size=3, stride=2, padding=1) def train(self): if self.learn is None: return self.learn.fit_one_cycle(3) @Inputs.data def set_data(self, data): if data is not None: self.data = data self.learn = Learner(self.data, self.model, loss_func=nn.CrossEntropyLoss(), metrics=accuracy, add_time=False, bn_wd=False, silent=True) self.label.setText(self.learn.summary()) else: self.data = None
class CNNM(OWWidget): name = "M CNN" description = "" # icon = "icons/robot.svg" want_main_area = True class Inputs: data = Input('Data', ImageDataBunch, default=True) def __init__(self): super().__init__() self.learn = None # train_button = gui.button(self.controlArea, self, "开始训练", callback=self.train) self.label = gui.label(self.mainArea, self, "模型结构") #: The current evaluating task (if any) self._task = None # type: Optional[Task] #: An executor we use to submit learner evaluations into a thread pool self._executor = ThreadExecutor() self.model = nn.Sequential( self.conv(1, 8), # 14 nn.BatchNorm2d(8), nn.ReLU(), self.conv(8, 16), # 7 nn.BatchNorm2d(16), nn.ReLU(), self.conv(16, 32), # 4 nn.BatchNorm2d(32), nn.ReLU(), self.conv(32, 16), # 2 nn.BatchNorm2d(16), nn.ReLU(), self.conv(16, 10), # 1 nn.BatchNorm2d(10), Flatten() # remove (1,1) grid ) def handleNewSignals(self): self._update() def _update(self): if self._task is not None: # First make sure any pending tasks are cancelled. self.cancel() assert self._task is None if self.data is None: return # collect all learners for which results have not yet been computed if not self.learn: return # setup the task state self._task = task = Task() # The learning_curve[_with_test_data] also takes a callback function # to report the progress. We instrument this callback to both invoke # the appropriate slots on this widget for reporting the progress # (in a thread safe manner) and to implement cooperative cancellation. set_progress = methodinvoke(self, "setProgressValue", (float,)) def callback(finished): # check if the task has been cancelled and raise an exception # from within. This 'strategy' can only be used with code that # properly cleans up after itself in the case of an exception # (does not leave any global locks, opened file descriptors, ...) if task.cancelled: raise KeyboardInterrupt() set_progress(finished * 100) self.progressBarInit() # Submit the evaluation function to the executor and fill in the # task with the resultant Future. # task.future = self._executor.submit(self.learn.fit_one_cycle(1)) with progress_disabled_ctx(self.learn) as learn: fit_model = partial(my_fit, learn, 1, callback=callback) task.future = self._executor.submit(fit_model) # Setup the FutureWatcher to notify us of completion task.watcher = FutureWatcher(task.future) # by using FutureWatcher we ensure `_task_finished` slot will be # called from the main GUI thread by the Qt's event loop task.watcher.done.connect(self._task_finished) @pyqtSlot(float) def setProgressValue(self, value): assert self.thread() is QThread.currentThread() self.progressBarSet(value) @pyqtSlot(concurrent.futures.Future) def _task_finished(self, f): """ Parameters ---------- f : Future The future instance holding the result of learner evaluation. """ assert self.thread() is QThread.currentThread() assert self._task is not None assert self._task.future is f assert f.done() self._task = None self.progressBarFinished() # try: # result = f.result() # type: List[Results] # except Exception as ex: # # Log the exception with a traceback # log = logging.getLogger() # log.exception(__name__, exc_info=True) # self.error("Exception occurred during evaluation: {!r}".format(ex)) # # clear all results # self.result= None # else: print(self.learn.validate()) # ... and update self.results def cancel(self): """ Cancel the current task (if any). """ if self._task is not None: self._task.cancel() assert self._task.future.done() # disconnect the `_task_finished` slot self._task.watcher.done.disconnect(self._task_finished) self._task = None def onDeleteWidget(self): self.cancel() super().onDeleteWidget() def conv(self, ni, nf): return nn.Conv2d(ni, nf, kernel_size=3, stride=2, padding=1) def train(self): if self.learn is None: return self.learn.fit_one_cycle(3) @Inputs.data def set_data(self, data): if data is not None: self.data = data self.learn = Learner(self.data, self.model, loss_func=nn.CrossEntropyLoss(), metrics=accuracy, add_time=False, bn_wd=False, silent=True) self.label.setText(self.learn.summary()) else: self.data = None