Ejemplo n.º 1
0
 def fit(self, train_dataloader: DataLoader, test_dataloader: DataLoader,
         loss_fn, optimizer):
     """
     Run ``self.config.num_epochs`` rounds of training followed by evaluation.

     Each round calls ``self.train_epoch`` and then ``self.test_epoch``; the
     first two items of each returned value are recorded as that round's
     loss and accuracy.

     :param train_dataloader: Dataloader passed to ``self.train_epoch``.
     :param test_dataloader: Dataloader passed to ``self.test_epoch``.
     :param loss_fn: Loss function forwarded to both epoch methods.
     :param optimizer: Optimizer forwarded to ``self.train_epoch``.
     :return: A FitResult built from the epoch count and the four per-epoch
         lists. It is also handed to ``self.save_checkpoint`` when
         ``self.checkpoint_file`` is not None.
     """
     tr_losses, tr_accs, te_losses, te_accs = [], [], [], []

     # The device is only selected and printed here; it is handed to the
     # epoch methods, which decide what to do with it.
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     print(device)

     total_epochs = self.config.num_epochs
     banner = '======== Epoch {:} / {:} ========'

     for epoch_idx in range(total_epochs):
         print(banner.format(epoch_idx + 1, total_epochs))

         train_out = self.train_epoch(train_dataloader, optimizer, loss_fn,
                                      device)
         tr_losses.append(train_out[0])
         tr_accs.append(train_out[1])

         test_out = self.test_epoch(test_dataloader, loss_fn, device)
         te_losses.append(test_out[0])
         te_accs.append(test_out[1])

     result = FitResult(total_epochs, tr_losses, tr_accs, te_losses,
                        te_accs)
     if self.checkpoint_file is not None:
         self.save_checkpoint(result)
     return result
Ejemplo n.º 2
0
    def fit(self,
            dl_train: DataLoader,
            dl_test: DataLoader,
            num_epochs,
            checkpoints: str = None,
            early_stopping: int = None,
            print_every=1,
            post_epoch_fn=None,
            **kw) -> FitResult:
        """
        Trains the model for multiple epochs with a given training set,
        and evaluates it on a given test set after every epoch.
        :param dl_train: Dataloader for the training set.
        :param dl_test: Dataloader for the test set.
        :param num_epochs: Number of epochs to train for.
        :param checkpoints: Checkpoint filename without extension ('.pt' is
            appended). When given, an existing checkpoint file is loaded
            before training, and the model is saved every time the test
            accuracy exceeds the best accuracy recorded so far.
        :param early_stopping: Stop training once the test accuracy has not
            improved over the previous epoch for this number of consecutive
            epochs.
        :param print_every: Print progress every this number of epochs (the
            last epoch is always verbose).
        :param post_epoch_fn: A function to call after each epoch completes,
            as ``post_epoch_fn(epoch, train_result, test_result, verbose)``.
            It is not called for the epoch that triggers early stopping.
        :param kw: Extra keyword arguments forwarded to ``train_epoch`` and
            ``test_epoch``.
        :return: A FitResult holding the number of epochs actually run, the
            accumulated train/test losses and the per-epoch accuracies.
        """
        actual_num_epochs = 0
        train_loss, train_acc, test_loss, test_acc = [], [], [], []

        best_acc = None
        epochs_without_improvement = 0

        checkpoint_filename = None
        if checkpoints is not None:
            checkpoint_filename = f'{checkpoints}.pt'
            # parents=True so a nested checkpoint directory such as
            # 'runs/exp1/model' is created instead of raising.
            Path(os.path.dirname(checkpoint_filename)).mkdir(parents=True,
                                                             exist_ok=True)
            if os.path.isfile(checkpoint_filename):
                print(f'*** Loading checkpoint file {checkpoint_filename}')
                saved_state = torch.load(checkpoint_filename,
                                         map_location=self.device)
                best_acc = saved_state.get('best_acc', best_acc)
                epochs_without_improvement =\
                    saved_state.get('ewi', epochs_without_improvement)
                self.model.load_state_dict(saved_state['model_state'])

        for epoch in range(num_epochs):
            save_checkpoint = False
            # Passed on to train/test_epoch.
            verbose = epoch % print_every == 0 or epoch == num_epochs - 1
            self._print(f'--- EPOCH {epoch+1}/{num_epochs} ---', verbose)

            train_result = self.train_epoch(dl_train, verbose=verbose, **kw)
            (loss, acc) = train_result
            train_loss += loss
            train_acc.append(acc)

            test_result = self.test_epoch(dl_test, verbose=verbose, **kw)
            (loss, acc) = test_result
            test_loss += loss

            actual_num_epochs += 1

            if checkpoints:
                # Compare against None explicitly: a recorded best accuracy
                # of 0 is a valid baseline, and a truthiness test would
                # overwrite it without ever saving the improved model.
                if best_acc is None:
                    best_acc = acc
                elif acc > best_acc:
                    best_acc = acc
                    save_checkpoint = True

            # Improvement is measured against the previous epoch's accuracy.
            if test_acc:
                if acc <= test_acc[-1]:
                    epochs_without_improvement += 1
                else:
                    epochs_without_improvement = 0

            test_acc.append(acc)

            if early_stopping:
                if epochs_without_improvement >= early_stopping:
                    break

            # Save model checkpoint if requested
            if save_checkpoint and checkpoint_filename is not None:
                saved_state = dict(best_acc=best_acc,
                                   ewi=epochs_without_improvement,
                                   model_state=self.model.state_dict())
                torch.save(saved_state, checkpoint_filename)
                print(f'*** Saved checkpoint {checkpoint_filename} '
                      f'at epoch {epoch+1}')

            if post_epoch_fn:
                post_epoch_fn(epoch, train_result, test_result, verbose)

        return FitResult(actual_num_epochs, train_loss, train_acc, test_loss,
                         test_acc)
Ejemplo n.º 3
0
    def fit(self,
            dl_train: DataLoader,
            dl_test: DataLoader,
            num_epochs,
            checkpoints: str = None,
            early_stopping: int = None,
            print_every=1,
            post_epoch_fn=None,
            **kw) -> FitResult:
        """
        Trains the model for multiple epochs with a given training set,
        evaluates it on a given test set after every epoch, and appends
        per-epoch metrics to text dump files.
        :param dl_train: Dataloader for the training set.
        :param dl_test: Dataloader for the test set.
        :param num_epochs: Number of epochs to train for.
        :param checkpoints: Checkpoint filename without extension ('.pt' is
            appended). When given, a checkpoint found next to this module is
            loaded before training, the model is saved every time the test
            accuracy exceeds the best accuracy recorded so far, and metrics
            are appended to ``Execution_dump_kernel_<tag>.txt`` and
            ``Results_raw_dump_kernel_<tag>.txt``. ``<tag>`` is the second
            '/'-separated component of the name, or the whole name when it
            contains no '/'. When None, no dump files are written.
        :param early_stopping: Stop training once the test accuracy has not
            improved over the previous epoch for this number of consecutive
            epochs.
        :param print_every: Print progress every this number of epochs (the
            last epoch is always verbose).
        :param post_epoch_fn: A function to call after each epoch completes,
            as ``post_epoch_fn(epoch, train_result, test_result, verbose)``.
            It is not called for the epoch that triggers early stopping.
        :param kw: Extra keyword arguments forwarded to ``train_epoch`` and
            ``test_epoch``.
        :return: A FitResult holding the number of epochs actually run, the
            accumulated train/test losses and the per-epoch accuracies.
        """
        actual_num_epochs = 0
        train_loss, train_acc, test_loss, test_acc = [], [], [], []

        best_acc = None
        epochs_without_improvement = 0

        checkpoint_filename = None
        dump_tag = None
        if checkpoints is not None:
            checkpoint_filename = f'{checkpoints}.pt'
            # parents=True so a nested checkpoint directory is created
            # instead of raising.
            Path(os.path.dirname(checkpoint_filename)).mkdir(parents=True,
                                                             exist_ok=True)

            # An existing checkpoint is looked up relative to the directory
            # of this module; once found, later saves go to that same path.
            module_dir = os.path.dirname(os.path.realpath(__file__))
            module_relative = module_dir + '//' + checkpoint_filename
            if os.path.isfile(module_relative):
                checkpoint_filename = module_relative
                print(f'*** Loading checkpoint file {checkpoint_filename}')
                saved_state = torch.load(checkpoint_filename,
                                         map_location=self.device)
                best_acc = saved_state.get('best_acc', best_acc)
                epochs_without_improvement =\
                    saved_state.get('ewi', epochs_without_improvement)
                self.model.load_state_dict(saved_state['model_state'])

            # Tag used in the dump file names. Fall back to the whole name
            # when it has no '/' rather than raising IndexError.
            name_parts = checkpoints.split("/")
            dump_tag = name_parts[1] if len(name_parts) > 1 else name_parts[0]

        for epoch in range(num_epochs):
            save_checkpoint = False
            # Passed on to train/test_epoch.
            verbose = epoch % print_every == 0 or epoch == num_epochs - 1
            self._print(f'--- EPOCH {epoch+1}/{num_epochs} ---', verbose)

            train_result = self.train_epoch(dl_train, verbose=verbose, **kw)
            (loss, acc, TP, TN, FP, FN, out, y) = train_result
            train_loss += loss
            train_acc.append(acc)
            tr_acc = round(acc, 2)
            tr_loss = loss[-1]

            test_result = self.test_epoch(dl_test, verbose=verbose, **kw)
            (loss, acc, TP, TN, FP, FN, out, y) = test_result
            te_acc = round(acc, 2)
            te_loss = loss[-1]
            test_loss += loss

            # No best accuracy recorded yet counts as "not best", which is
            # what the dump reported before; stated explicitly instead of
            # relying on a swallowed TypeError.
            is_best = best_acc is not None and te_acc > best_acc

            if dump_tag is not None:
                with open(f"Execution_dump_kernel_{dump_tag}.txt",
                          "a") as myfile:
                    myfile.write(
                        f'{actual_num_epochs} \t {tr_acc} \t {te_acc} \t {tr_loss} \t {te_loss} \t {is_best} \n'
                    )
                # Test-set confusion counts are dumped for every epoch.
                with open(f"Results_raw_dump_kernel_{dump_tag}.txt",
                          "a") as myfile:
                    myfile.write(
                        f'{actual_num_epochs} \t {te_acc} \t {te_loss} \t {is_best} \t {TP.item()} \t {TN.item()} \t {FP.item()} \t {FN.item()}\n'
                    )
            actual_num_epochs += 1

            if checkpoints:
                # Compare against None explicitly: a recorded best accuracy
                # of 0 is a valid baseline, and a truthiness test would
                # overwrite it without ever saving the improved model.
                if best_acc is None:
                    best_acc = acc
                elif acc > best_acc:
                    best_acc = acc
                    save_checkpoint = True

            # Improvement is measured against the previous epoch's accuracy.
            if test_acc:
                if acc <= test_acc[-1]:
                    epochs_without_improvement += 1
                else:
                    epochs_without_improvement = 0

            test_acc.append(acc)

            if early_stopping:
                if epochs_without_improvement >= early_stopping:
                    break

            # Save model checkpoint if requested
            if save_checkpoint and checkpoint_filename is not None:
                saved_state = dict(best_acc=best_acc,
                                   ewi=epochs_without_improvement,
                                   model_state=self.model.state_dict())
                torch.save(saved_state, checkpoint_filename)
                print(f'*** Saved checkpoint {checkpoint_filename} '
                      f'at epoch {epoch+1}')

            if post_epoch_fn:
                post_epoch_fn(epoch, train_result, test_result, verbose)

        return FitResult(actual_num_epochs, train_loss, train_acc, test_loss,
                         test_acc)
Ejemplo n.º 4
0
        else:
            ax.set_xlabel('Epoch #')
            ax.set_ylabel('Accuracy (%)')
        if legend:
            ax.legend()

    plt.show()
    # plt.savefig('resnet.png')


# Script entry point: parse a training log and collect per-epoch metrics
# into a FitResult.
if __name__ == '__main__':

    # NOTE(review): the file handle is not closed anywhere in the visible
    # lines — confirm it is closed further down or switch to a `with` block.
    fp = open("slurm-test-srmnet.out", 'rb')
    # Start from empty metric lists; they are filled in by the loop below.
    fr = FitResult(num_epochs=85,
                   train_acc=[],
                   train_loss=[],
                   test_acc=[],
                   test_loss=[])

    for line in fp.readlines():
        # The file is opened in binary mode, so each line is decoded here.
        line = line.decode('utf-8')
        # Only lines containing 'Avg. Loss' carry metrics.
        if 'Avg. Loss' in line:
            if 'train_' in line:
                # Loss is the text between 'Avg. Loss ' and the next ',';
                # accuracy is the text between 'Accuracy ' and the next ')'.
                tttt = float(line.split('Avg. Loss ')[1].split(',')[0])
                zzzz = float(line.split('Accuracy ')[1].split(')')[0])
                fr.train_loss.append(tttt)
                fr.train_acc.append(zzzz)
            if 'test_' in line:
                tttt = float(line.split('Avg. Loss ')[1].split(',')[0])
                zzzz = float(line.split('Accuracy ')[1].split(')')[0])
                # NOTE(review): the parsed test accuracy (zzzz) is not
                # appended within the visible lines — confirm it is stored
                # in fr.test_acc right after this statement.
                fr.test_loss.append(tttt)
Ejemplo n.º 5
0
    def fit(self, dl_train:DataLoader, dl_test:DataLoader, checkpoints:str=None,
            early_stopping:int=None, print_every:int=1, post_epoch_fn=None, **kw) -> FitResult:
        """
        Trains the model for ``self.num_epochs`` epochs with a given training set,
        and calculates the test loss over a given test set after every epoch.
        :param dl_train: DataLoader for the training set.
        :param dl_test: DataLoader for the test set.
        :param checkpoints: Checkpoint filename without extension ('.pt' is appended).
        When given, an existing checkpoint file is loaded before training and the
        model is saved every time the mean test loss drops below the best value so far.
        :param early_stopping: Stop training once the number of consecutive epochs
        without a new best mean test loss equals this value.
        :param print_every: Print progress every this number of epochs (the last
        epoch is always verbose).
        :param post_epoch_fn: A function to call after each epoch completes, as
        ``post_epoch_fn(model=self.model, device=self.device, dl_test=dl_test)``.
        It is not called for the epoch that triggers early stopping.
        :param kw: Accepted for interface compatibility; not used by this method.
        :return: A FitResult holding the number of epochs actually run and the
        per-epoch train and test losses.
        """
        train_loss, test_loss = [], []
        best_loss = None
        actual_num_epochs = 0
        epochs_without_improvement = 0

        checkpoint_filename = None
        if checkpoints is not None:
            checkpoint_filename = f'{checkpoints}.pt'
            # parents=True so a nested checkpoint directory is created instead of raising.
            Path(os.path.dirname(checkpoint_filename)).mkdir(parents=True, exist_ok=True)
            if os.path.isfile(checkpoint_filename):
                print(f'[I] - Loading checkpoint file {checkpoint_filename}')
                saved_state = torch.load(checkpoint_filename, map_location=self.device)
                # Restore the best loss as well. Without it the first epoch after a
                # resume resets the baseline, and a model worse than the stored one
                # could later overwrite the checkpoint.
                best_loss = saved_state.get('best_loss', best_loss)
                epochs_without_improvement = saved_state.get('ewi', epochs_without_improvement)
                self.model.net.load_state_dict(saved_state['model_state'])

        for epoch in range(self.num_epochs):
            save_checkpoint = False
            verbose = epoch % print_every == 0 or epoch == self.num_epochs - 1

            self._print(f'--- EPOCH {epoch+1}/{self.num_epochs} ---', verbose) # Conditional verbose

            train_result = self.train_epoch(dl_train)
            test_result = self.test_epoch(dl_test)
            train_loss.append(train_result.losses)
            test_loss.append(test_result.losses)
            actual_num_epochs += 1

            # Mean of this epoch's test losses, computed once and reused below.
            epoch_test_loss = np.mean(test_loss[-1])
            if best_loss is None:
                # First measurement only sets the baseline; nothing is saved yet.
                best_loss = epoch_test_loss
            elif best_loss > epoch_test_loss:
                best_loss = epoch_test_loss
                epochs_without_improvement = 0
                save_checkpoint = True
            else:
                # Count the number of epochs without improvement for early stopping
                epochs_without_improvement += 1

            if early_stopping is not None and epochs_without_improvement == early_stopping:
                break

            # Save model checkpoint if requested
            if save_checkpoint and checkpoint_filename is not None:
                saved_state = dict(best_loss=best_loss,
                                   ewi=epochs_without_improvement,
                                   model_state=self.model.net.state_dict())
                torch.save(saved_state, checkpoint_filename)
                print(f'[I] - Saved checkpoint {checkpoint_filename} at epoch {epoch+1}')

            if post_epoch_fn:
                post_epoch_fn(model=self.model, device=self.device, dl_test=dl_test)

        return FitResult(actual_num_epochs, train_loss, test_loss)
Ejemplo n.º 6
0
    def fit(self, dl_train: DataLoader, dl_test: DataLoader,
            num_epochs, checkpoints: str = None,
            early_stopping: int = None,
            print_every=1, post_epoch_fn=None, **kw) -> FitResult:

        """
        Train and evaluate the model epoch by epoch, resuming from a
        checkpoint when one exists.
        :param dl_train: Dataloader for the training set.
        :param dl_test: Dataloader for the test set.
        :param num_epochs: Total number of epochs to reach, counting epochs
            already recorded in a loaded checkpoint.
        :param checkpoints: Checkpoint filename without extension ('.pt' is
            appended). When given, an existing file is loaded (model state,
            counters and the stored fit history) and a checkpoint is written
            after every completed epoch.
        :param early_stopping: Stop training once the mean test loss has
            failed to drop below the previous epoch's value for this number
            of consecutive epochs.
        :param print_every: Print progress every this number of epochs (the
            last epoch is always verbose).
        :param post_epoch_fn: A function to call after each epoch completes,
            as ``post_epoch_fn(epoch, train_res, test_res, verbose)``.
        :param kw: Only ``max_batches`` is read; it is forwarded to
            ``train_epoch`` and ``test_epoch``.
        :return: A FitResult holding the epoch count and the per-epoch mean
            train/test losses and accuracies.
        """

        def _mean(values):
            # Arithmetic mean of one epoch's batch losses.
            return sum(values) / len(values)

        # Missing key yields None, the same value the epoch methods receive
        # when no batch limit was requested.
        max_batches = kw.get("max_batches")

        train_loss, train_acc, test_loss, test_acc = [], [], [], []
        actual_num_epochs = 0
        best_acc = None
        epochs_without_improvement = 0
        epochs_so_far = 0

        checkpoint_filename = None
        if checkpoints is not None:
            checkpoint_filename = f'{checkpoints}.pt'
            Path(os.path.dirname(checkpoint_filename)).mkdir(exist_ok=True)
            if os.path.isfile(checkpoint_filename):
                print(f'*** Loading checkpoint file {checkpoint_filename}')
                state = torch.load(checkpoint_filename,
                                   map_location=self.device)
                best_acc = state.get('best_acc', best_acc)
                epochs_without_improvement = state.get(
                    'ewi', epochs_without_improvement)
                epochs_so_far = state.get('esf', epochs_so_far)
                actual_num_epochs += epochs_so_far
                self.model.load_state_dict(state['model_state'])
                # Items 1..4 of the stored fit result are the four history
                # lists; training continues appending to them.
                previous_fit = state['fit_result']
                train_loss, train_acc, test_loss, test_acc = (
                    previous_fit[i] for i in range(1, 5))

        # Nothing left to train: run a single evaluation pass so the
        # returned history still gets a fresh test measurement.
        if epochs_so_far == num_epochs:
            eval_only = self.test_epoch(dl_test, max_batches=max_batches)
            test_loss.append(_mean(eval_only.losses))
            test_acc.append(eval_only.accuracy)

        for epoch in range(epochs_so_far, num_epochs):
            # Passed on to train/test_epoch.
            verbose = epoch % print_every == 0 or epoch == num_epochs - 1
            self._print(f'--- EPOCH {epoch+1}/{num_epochs} ---', verbose)

            actual_num_epochs += 1
            train_res = self.train_epoch(dl_train, verbose=verbose,
                                         max_batches=max_batches)
            test_res = self.test_epoch(dl_test, verbose=verbose,
                                       max_batches=max_batches)

            train_loss.append(_mean(train_res.losses))
            train_acc.append(train_res.accuracy)
            test_loss.append(_mean(test_res.losses))
            test_acc.append(test_res.accuracy)

            # Early stopping compares the newest mean test loss with the
            # one from the epoch before it.
            if early_stopping is not None and len(test_loss) >= 2:
                if test_loss[-1] >= test_loss[-2]:
                    epochs_without_improvement += 1
                    if epochs_without_improvement == early_stopping:
                        break
                else:
                    epochs_without_improvement = 0

            accuracy_floor = 0 if best_acc is None else best_acc
            best_acc = max(accuracy_floor, test_res.accuracy)

            # A checkpoint is written after every completed epoch whenever
            # a checkpoint file was requested.
            if checkpoint_filename is not None:
                state = dict(best_acc=best_acc,
                             ewi=epochs_without_improvement,
                             model_state=self.model.state_dict(),
                             esf=epoch+1,
                             fit_result=FitResult(actual_num_epochs,
                                                  train_loss, train_acc,
                                                  test_loss, test_acc))
                torch.save(state, checkpoint_filename)
                print(f'*** Saved checkpoint {checkpoint_filename} '
                      f'at epoch {epoch+1}')

            if post_epoch_fn:
                post_epoch_fn(epoch, train_res, test_res, verbose)

        return FitResult(actual_num_epochs,
                         train_loss, train_acc, test_loss, test_acc)