Example #1
import time
import numpy as np
import torch
from torch.autograd import Variable
# LeNet and the global `args` come from the surrounding project (not shown in this excerpt).

def forward_pytorch(weightfile, image):
    # net = resnet.resnet18()
    net = LeNet(1, 2)
    checkpoint = torch.load(weightfile)
    net.load_state_dict(checkpoint['weight'])
    net.double()                # to double
    if args.cuda:
        net.cuda()
    print(net)
    net.eval()
    image = torch.from_numpy(image.astype(np.float64)) # to double

    if args.cuda:
        image = Variable(image.cuda())
    else:
        image = Variable(image)
    t0 = time.time()
    blobs = net(image)
    t1 = time.time()
    print(blobs.data.cpu().numpy().flatten())   # .cpu() so this also works when --cuda is set
    return t1 - t0, blobs, net, torch.from_numpy(blobs.data.cpu().numpy())
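A minimal usage sketch for the function above, assuming a grayscale test image on disk; the file names and the args stand-in are placeholders, not part of the original script:

# Hypothetical driver code; 'weights.pth' and 'digit.png' are placeholder paths.
import argparse
import cv2

args = argparse.Namespace(cuda=False)   # stand-in for the project's argparse flags
img = cv2.imread('digit.png', cv2.IMREAD_GRAYSCALE)             # H x W uint8
img = img.reshape(1, 1, *img.shape).astype(np.float64) / 255.0  # N x C x H x W, double
elapsed, blobs, net, out = forward_pytorch('weights.pth', img)
print('forward time: %.4f s' % elapsed)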
Example #2
import torch as th
import torchvision as tv
import torch.nn as nn
import torch.optim as optim

from torch.autograd import Variable as V
from torchvision import transforms

from lenet import LeNet
from sobolev import SobolevLoss

USE_SOBOLEV = False

student = LeNet()
teacher = LeNet()
teacher.load_state_dict(th.load('teacher.pth'))
student = student.cuda()
teacher = teacher.cuda()

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
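The excerpt stops after defining the transforms; a minimal sketch of feeding them into CIFAR-10 loaders (the dataset is inferred from the normalization statistics, and the batch size and data root are assumptions):

# Sketch only: dataset, root and batch size are assumptions, not from the original snippet.
from torch.utils.data import DataLoader

trainset = tv.datasets.CIFAR10(root='./data', train=True, download=True,
                               transform=transform_train)
testset = tv.datasets.CIFAR10(root='./data', train=False, download=True,
                              transform=transform_test)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)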
def main():

    parser = argparse.ArgumentParser()
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument("--train",
                            action="store_true",
                            help="To train the network.")
    mode_group.add_argument("--test",
                            action="store_true",
                            help="To test the network.")
    parser.add_argument("--epochs",
                        default=10,
                        type=int,
                        help="Desired number of epochs.")
    parser.add_argument("--dropout",
                        action="store_true",
                        help="Whether to use dropout or not.")
    parser.add_argument("--uncertainty",
                        action="store_true",
                        help="Use uncertainty or not.")
    parser.add_argument("--dataset",
                        action="store_true",
                        help="The dataset to use.")
    parser.add_argument("--outsample",
                        action="store_true",
                        help="Use out of sample test image")

    uncertainty_type_group = parser.add_mutually_exclusive_group()
    uncertainty_type_group.add_argument(
        "--mse",
        action="store_true",
        help=
        "Set this argument when using uncertainty. Sets loss function to Expected Mean Square Error."
    )
    uncertainty_type_group.add_argument(
        "--digamma",
        action="store_true",
        help=
        "Set this argument when using uncertainty. Sets loss function to Expected Cross Entropy."
    )
    uncertainty_type_group.add_argument(
        "--log",
        action="store_true",
        help=
        "Set this argument when using uncertainty. Sets loss function to Negative Log of the Expected Likelihood."
    )

    dataset_type_group = parser.add_mutually_exclusive_group()
    dataset_type_group.add_argument(
        "--mnist",
        action="store_true",
        help="Set this argument when using MNIST dataset")
    dataset_type_group.add_argument(
        "--emnist",
        action="store_true",
        help="Set this argument when using EMNIST dataset")
    dataset_type_group.add_argument(
        "--CIFAR",
        action="store_true",
        help="Set this argument when using CIFAR dataset")
    dataset_type_group.add_argument(
        "--fmnist",
        action="store_true",
        help="Set this argument when using FMNIST dataset")
    args = parser.parse_args()

    if args.dataset:
        if args.mnist:
            from mnist import dataloaders, label_list
        elif args.CIFAR:
            from CIFAR import dataloaders, label_list
        elif args.fmnist:
            from fashionMNIST import dataloaders, label_list

    if args.train:
        num_epochs = args.epochs
        use_uncertainty = args.uncertainty
        num_classes = 10
        model = LeNet(dropout=args.dropout)

        if use_uncertainty:
            if args.digamma:
                criterion = edl_digamma_loss
            elif args.log:
                criterion = edl_log_loss
            elif args.mse:
                criterion = edl_mse_loss
            else:
                parser.error(
                    "--uncertainty requires --mse, --log or --digamma.")
        else:
            criterion = nn.CrossEntropyLoss()

        optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005)

        exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                     step_size=7,
                                                     gamma=0.1)

        device = get_device()
        model = model.to(device)

        model, metrics = train_model(model,
                                     dataloaders,
                                     num_classes,
                                     criterion,
                                     optimizer,
                                     scheduler=exp_lr_scheduler,
                                     num_epochs=num_epochs,
                                     device=device,
                                     uncertainty=use_uncertainty)

        state = {
            "epoch": num_epochs,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }

        if use_uncertainty:
            if args.digamma:
                torch.save(state, "./results/model_uncertainty_digamma.pt")
                print("Saved: ./results/model_uncertainty_digamma.pt")
            if args.log:
                torch.save(state, "./results/model_uncertainty_log.pt")
                print("Saved: ./results/model_uncertainty_log.pt")
            if args.mse:
                torch.save(state, "./results/model_uncertainty_mse.pt")
                print("Saved: ./results/model_uncertainty_mse.pt")

        else:
            torch.save(state, "./results/model.pt")
            print("Saved: ./results/model.pt")

    elif args.test:

        use_uncertainty = args.uncertainty
        device = get_device()
        model = LeNet()
        model = model.to(device)
        optimizer = optim.Adam(model.parameters())

        if use_uncertainty:
            if args.digamma:
                checkpoint = torch.load(
                    "./results/model_uncertainty_digamma.pt")
            if args.log:
                checkpoint = torch.load("./results/model_uncertainty_log.pt")
            if args.mse:
                checkpoint = torch.load("./results/model_uncertainty_mse.pt")
        else:
            checkpoint = torch.load("./results/model.pt")

        filename = "./results/rotate.jpg"
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        model.eval()
        if args.outsample:
            img = Image.open("./data/arka.jpg").convert('L').resize((28, 28))
            img = TF.to_tensor(img)
            img.unsqueeze_(0)
        else:
            a = iter(dataloaders['test'])
            img, label = next(a)
        rotating_image_classification(model,
                                      img,
                                      filename,
                                      label_list,
                                      uncertainty=use_uncertainty)

        img = transforms.ToPILImage()(img[0][0])
        test_single_image(model, img, label_list, uncertainty=use_uncertainty)
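get_device() is called above but not defined in the excerpt; a plausible one-line implementation (an assumption, consistent with how the model is moved to the device):

# Hypothetical helper; the original project supplies its own get_device().
def get_device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")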
Example #4
File: main.py  Project: afcidk/opencv-1
class Ui(QtWidgets.QMainWindow):
    buttons = [
       "load_image", "color_conversion", "image_flipping",
       "blending", "global_threshold", "local_threshold",
       "gaussian", "sobel_x", "sobel_y", "magnitude", "rst",
       "show_train_image", "show_hyper", "train_1", "pt",
       "inference", "ok", "show_train_result", "cancel"]
    inputs = ["angle", "scale", "tx", "ty", "test_index"]

    def __init__(self):
        super(Ui, self).__init__()

        uic.loadUi('main_window.ui', self)
        self.get_widgets()
        self.get_input()
        self.bind_event()
        self.param_setup()
        self.torch_setup()
        self.show()

    def get_widgets(self):
        for btn in self.buttons:
            setattr(self, btn, self.findChild(QtWidgets.QPushButton, btn))
    
    def get_input(self):
        for inp in self.inputs:
            setattr(self, inp, self.findChild(QtWidgets.QLineEdit, inp))


    def bind_event(self):
        for btn in self.buttons:
            getattr(self, btn).clicked.connect(partial(
                getattr(events,  btn), 
                self))
    def param_setup(self):
        self.batch_size = 32
        self.learning_rate = 0.001
        self.opt = "SGD"
        self.loss_list = []
        self.loss_epoch = []
        self.acc_train_epoch = []
        self.acc_test_epoch = []
        self.compose = transforms.Compose([
            transforms.Resize((32,32)),
            transforms.ToTensor()
        ])


    def torch_setup(self):
        self.data_train = MNIST('./data/mnist',
                            train=True,
                            download=True,
                            transform=self.compose)
                            
        self.data_test = MNIST('./data/mnist',
                            train=False,
                            download=True,
                            transform=self.compose)
        self.data_train_loader = DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True, num_workers=4)
        self.data_test_loader = DataLoader(self.data_test, batch_size=self.batch_size, num_workers=4)
        self.criterion = nn.CrossEntropyLoss()
        self.net = LeNet()
        self.optimizer = getattr(optim, self.opt)(self.net.parameters(), lr=self.learning_rate)

        try:
            self.net.load_state_dict(load('model_params.pkl'))
            self.loaded = True
            print("Loaded")
        except Exception as e:
            print(e)
            self.loaded = False
            print("Not loaded")

    def train(self, epoch):
        self.net.train()
        self.loss_list = []
        correct, total = 0, 0
        for i, (images, labels) in enumerate(self.data_train_loader):
            self.optimizer.zero_grad()
            output = self.net(images)
            loss = self.criterion(output, labels)
            pred = output.data.max(1, keepdim=True)[1]
            correct += np.sum(np.squeeze(pred.eq(labels.data.view_as(pred))).cpu().numpy())
            total += images.size(0)
            self.loss_list.append(loss.detach().cpu().item())

            if i % 100 == 0:
                print(f'Train - Epoch {epoch}, Batch: {i}, Loss: {loss.detach().cpu().item()}')

            loss.backward()
            self.optimizer.step()
        self.acc_train_epoch.append(correct/total)
        self.loss_epoch.append(sum(self.loss_list)/len(self.loss_list))

    def test(self):
        self.net.eval()
        total_correct, total_loss = 0, 0.0
        with torch.no_grad():                    # no gradients needed during evaluation
            for images, labels in self.data_test_loader:
                output = self.net(images)
                total_loss += self.criterion(output, labels).item() * images.size(0)
                pred = output.max(1)[1]
                total_correct += pred.eq(labels.view_as(pred)).sum().item()

        avg_loss = total_loss / len(self.data_test)   # mean loss per sample
        acc = total_correct / len(self.data_test)
        self.acc_test_epoch.append(acc)

    def test_and_train(self, epoch):
        self.train(epoch)
        self.test()
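torch_setup() tries to restore weights from 'model_params.pkl', but the matching save step is not in the excerpt; a minimal sketch of one (the bare save mirrors the bare load used above and is assumed to be torch.save):

# Hypothetical counterpart to the load in torch_setup(); not part of the original class.
from torch import save

def save_params(ui):
    # Persist the trained LeNet weights so torch_setup() can reload them on the next run.
    save(ui.net.state_dict(), 'model_params.pkl')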
Example #5
    int_version = 4
    name = 'LeNet{}x{}_{}'.format(input_size_h, input_size_w, int_version)

    # hardware setting
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'  # the conversion currently works only on CPU

    net = LeNet(in_channel, number_classes)

    #if(device == 'cuda'):
    net = net.to(device)

    # load a pre-trained pyTorch model
    #checkpoint = torch.load("./lenet32x40_ckpt_0_60.667504.pth")
    checkpoint = torch.load("./ckpt_32x40_lenet_3_277_99.206645.pth")
    net.load_state_dict(checkpoint['weight'])

    # to run on CPU, move both the net and the input there:
    # net = net.to('cpu')
    # input_ = input_.to('cpu')

    net.eval()

    # dummy input used to trace the network through pytorch_to_caffe
    input_ = torch.ones([1, in_channel, input_size_h, input_size_w])
    input_ = input_.to(device)   # keep the name input_ to avoid shadowing the built-in input()
    # input_ = torch.ones([1, 3, 224, 224])
    # cuda problem ...
    pytorch_to_caffe.trans_net(net, input_, name)

    pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))
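To sanity-check the export, the generated files could be loaded back with pycaffe and compared against the PyTorch output; this is a sketch under the assumption that pycaffe is installed (the input-blob name is looked up rather than assumed):

# Optional consistency check with pycaffe (requires a Caffe installation); not in the original script.
import numpy as np
import caffe

caffe.set_mode_cpu()
net_caffe = caffe.Net('{}.prototxt'.format(name), '{}.caffemodel'.format(name), caffe.TEST)
in_blob = net_caffe.inputs[0]                        # input blob name assigned by the converter
net_caffe.blobs[in_blob].data[...] = input_.numpy()  # same dummy input as above
caffe_out = net_caffe.forward()
torch_out = net(input_).detach().numpy()
print('max abs diff:', np.abs(list(caffe_out.values())[0] - torch_out).max())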
Example #6
                        default=False,
                        help="resize images; only needed for LeNet.")
    parser.add_argument("--img_store",
                        type=str,
                        default="data/visualsem_images")
    parser.add_argument("--marking_dict",
                        type=str,
                        default="data/marking_dict.json",
                        help="marking dict")
    args = parser.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    if args.model == "lenet":
        net = LeNet(args.width).to(device)
        net.load_state_dict(torch.load("data/lenet_test3000_19"))
    elif args.model == "resnet":
        net = ResNet152().to(device)
        net.load_state_dict(torch.load("data/resnet_test3000_26"))
    elif args.model == "vgg":
        net = VGG19_BN().to(device)
        net.load_state_dict(torch.load("data/vgg_test3000_28"))

    net.eval()
    transform = TRANSFORMS[MODELS[args.model]]

    with open(args.file_data, "r") as f:
        val = json.loads(f.read())

    #with open("../marking_dict.json", "r") as f:
#        marking_dict = json.loads(f.read())
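The excerpt ends before the evaluation loop over val; a minimal sketch of pushing a single image through the loaded net (the image path, the use of PIL, and the class-index handling are illustrative assumptions):

# Illustrative only: 'example.jpg' is a placeholder, not a path from the original script.
from PIL import Image

img = Image.open("example.jpg").convert("RGB")
x = transform(img).unsqueeze(0).to(device)   # apply the model-specific preprocessing
with torch.no_grad():
    pred = net(x).argmax(dim=1).item()
print("predicted class index:", pred)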
Example #7
def training(model_name, trainloader, validloader, input_channel=3, epochs=1, resume=True, self_define=True, only_print=False):
    # load self defined or official net
    assert model_name in ["LeNet", "VGG16", "ResNet", "DenseNet"]

    if self_define:
        if model_name == "LeNet":
            net = LeNet(input_channel)
        elif model_name == "VGG16":
            net = VGG16(input_channel)
        elif model_name == "ResNet":
            net = ResNet(input_channel)
        elif model_name == "DenseNet":
            net = DenseNet(input_channel)
    else:
        if model_name == "LeNet":
            net = LeNet(input_channel)  # torchvision has no official LeNet, so fall back to the self-defined one
        elif model_name == "VGG16":
            net = models.vgg16_bn(pretrained=False, num_classes=10)
        elif model_name == "ResNet":
            net = models.resnet50(pretrained=False, num_classes=10)
        elif model_name == "DenseNet":
            net = models.DenseNet(num_classes=10)

    # sum of net parameters number
    print("Number of trainable parameters in %s : %f" % (model_name, sum(p.numel() for p in net.parameters() if p.requires_grad)))

    # print model structure
    if only_print:
        print(net)
        return

    # resume training
    param_path = "./model/%s_%s_parameter.pt" % (model_name, "define" if self_define else "official")
    if resume:
        if os.path.exists(param_path):
            net.load_state_dict(torch.load(param_path))
            net.train()
            print("Resume training " + model_name)
        else:
            print("Train %s from scratch" % model_name)
    else:
        print("Train %s from scratch" % model_name)

    # define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    # train on GPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('train on %s' % device)
    net.to(device)

    running_loss = 0.0
    train_losses = []
    valid_losses = []
    mini_batches = 125 * 5
    for epoch in range(epochs):
        for i, data in enumerate(trainloader, 0):
            # get one batch
            # inputs, labels = data
            inputs, labels = data[0].to(device), data[1].to(device)
    
            # switch model to training mode, clear gradient accumulators
            net.train()
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            if i % mini_batches == mini_batches - 1:  # print and valid every <mini_batches> mini-batches
                # validate model in validation dataset
                valid_loss = valid(net, validloader, criterion, device)
                print('[%d, %5d] train loss: %.3f,  validset loss: %.3f' % (
                    epoch + 1, i + 1, running_loss / mini_batches, valid_loss))
                train_losses.append(running_loss / mini_batches)
                valid_losses.append(valid_loss)
                running_loss = 0.0

        # save parameters
        torch.save(net.state_dict(), param_path)

        # # save checkpoint
        # torch.save({
        #     'epoch': epoch,
        #     'model_state_dict': net.state_dict(),
        #     'optimizer_state_dict': optimizer.state_dict(),
        #     'loss': loss
        # }, "./checkpoints/epoch_" + str(epoch) + ".tar")
    
    print('Finished Training, %d images in all' % (len(train_losses) * batch_size * mini_batches / epochs))
    
    # draw loss curve
    assert len(train_losses) == len(valid_losses)
    loss_x = range(0, len(train_losses))
    plt.plot(loss_x, train_losses, label="train loss")
    plt.plot(loss_x, valid_losses, label="valid loss")
    plt.title("Loss for every %d mini-batch" % mini_batches)
    plt.xlabel("%d mini-batches" % mini_batches)
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(model_name + "_loss.png")
    plt.show()
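valid() is called inside the loop above but is not part of the excerpt; a minimal sketch consistent with the call site valid(net, validloader, criterion, device), returning the mean validation loss:

# Hypothetical helper matching the call valid(net, validloader, criterion, device).
def valid(net, validloader, criterion, device):
    net.eval()
    total_loss, total = 0.0, 0
    with torch.no_grad():
        for inputs, labels in validloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            total += labels.size(0)
    return total_loss / total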
def main():

    parser = argparse.ArgumentParser()
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument("--train",
                            action="store_true",
                            help="To train the network.")
    mode_group.add_argument("--test",
                            action="store_true",
                            help="To test the network.")
    mode_group.add_argument("--examples",
                            action="store_true",
                            help="To example MNIST data.")
    parser.add_argument("--epochs",
                        default=10,
                        type=int,
                        help="Desired number of epochs.")
    parser.add_argument("--dropout",
                        action="store_true",
                        help="Whether to use dropout or not.")
    parser.add_argument("--uncertainty",
                        action="store_true",
                        help="Use uncertainty or not.")
    uncertainty_type_group = parser.add_mutually_exclusive_group()
    uncertainty_type_group.add_argument(
        "--mse",
        action="store_true",
        help=
        "Set this argument when using uncertainty. Sets loss function to Expected Mean Square Error."
    )
    uncertainty_type_group.add_argument(
        "--digamma",
        action="store_true",
        help=
        "Set this argument when using uncertainty. Sets loss function to Expected Cross Entropy."
    )
    uncertainty_type_group.add_argument(
        "--log",
        action="store_true",
        help=
        "Set this argument when using uncertainty. Sets loss function to Negative Log of the Expected Likelihood."
    )
    args = parser.parse_args()

    if args.examples:
        examples = enumerate(dataloaders["val"])
        batch_idx, (example_data, example_targets) = next(examples)
        fig = plt.figure()
        for i in range(6):
            plt.subplot(2, 3, i + 1)
            plt.tight_layout()
            plt.imshow(example_data[i][0], cmap="gray", interpolation="none")
            plt.title("Ground Truth: {}".format(example_targets[i]))
            plt.xticks([])
            plt.yticks([])
        plt.savefig("./images/examples.jpg")

    elif args.train:
        num_epochs = args.epochs
        use_uncertainty = args.uncertainty
        num_classes = 10

        model = LeNet(dropout=args.dropout)

        if use_uncertainty:
            if args.digamma:
                criterion = edl_digamma_loss
            elif args.log:
                criterion = edl_log_loss
            elif args.mse:
                criterion = edl_mse_loss
            else:
                parser.error(
                    "--uncertainty requires --mse, --log or --digamma.")
        else:
            criterion = nn.CrossEntropyLoss()

        optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005)

        exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                     step_size=7,
                                                     gamma=0.1)

        device = get_device()
        model = model.to(device)

        model, metrics = train_model(model,
                                     dataloaders,
                                     num_classes,
                                     criterion,
                                     optimizer,
                                     scheduler=exp_lr_scheduler,
                                     num_epochs=num_epochs,
                                     device=device,
                                     uncertainty=use_uncertainty)

        state = {
            "epoch": num_epochs,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }

        if use_uncertainty:
            if args.digamma:
                torch.save(state, "./results/model_uncertainty_digamma.pt")
                print("Saved: ./results/model_uncertainty_digamma.pt")
            if args.log:
                torch.save(state, "./results/model_uncertainty_log.pt")
                print("Saved: ./results/model_uncertainty_log.pt")
            if args.mse:
                torch.save(state, "./results/model_uncertainty_mse.pt")
                print("Saved: ./results/model_uncertainty_mse.pt")

        else:
            torch.save(state, "./results/model.pt")
            print("Saved: ./results/model.pt")

    elif args.test:

        use_uncertainty = args.uncertainty
        device = get_device()
        model = LeNet()
        model = model.to(device)
        optimizer = optim.Adam(model.parameters())

        if use_uncertainty:
            if args.digamma:
                checkpoint = torch.load(
                    "./results/model_uncertainty_digamma.pt")
                filename = "./results/rotate_uncertainty_digamma.jpg"
            if args.log:
                checkpoint = torch.load("./results/model_uncertainty_log.pt")
                filename = "./results/rotate_uncertainty_log.jpg"
            if args.mse:
                checkpoint = torch.load("./results/model_uncertainty_mse.pt")
                filename = "./results/rotate_uncertainty_mse.jpg"

        else:
            checkpoint = torch.load("./results/model.pt")
            filename = "./results/rotate.jpg"

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        model.eval()

        rotating_image_classification(model,
                                      digit_one,
                                      filename,
                                      uncertainty=use_uncertainty)

        img = Image.open("./data/one.jpg").convert('L')

        test_single_image(model, img, uncertainty=use_uncertainty)
Example #9
def evalidation(model_name, testloader, classes, input_channel=3, self_define=True):
    dataiter = iter(testloader)
    images, labels = next(dataiter)  # dataiter.next() no longer exists in Python 3

    # print images
    imshow(torchvision.utils.make_grid(images))
    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))

    # load model parameter
    assert model_name in ["LeNet", "VGG16", "ResNet", "DenseNet"]
    param_path = "./model/%s_%s_parameter.pt" % (model_name, "define" if self_define else "official")
    print("load model parameter from %s" % param_path)
    if self_define:
        if model_name == "LeNet":
            net = LeNet(input_channel)
        elif model_name == "VGG16":
            net = VGG16(input_channel)
        elif model_name == "ResNet":
            net = ResNet(input_channel)
        elif model_name == "DenseNet":
            net = DenseNet(input_channel)
    else:
        if model_name == "LeNet":
            net = LeNet(input_channel)
        elif model_name == "VGG16":
            net = models.vgg16_bn(pretrained=False, num_classes=10)
        elif model_name == "ResNet":
            net = models.resnet50(pretrained=False, num_classes=10)
        elif model_name == "DenseNet":
            net = models.DenseNet(num_classes=10)


    net.load_state_dict(torch.load(param_path))
    net.eval()

    # predict
    outputs = net(images)
    _, predicted = torch.max(outputs, 1)
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(batch_size)))

    # to gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)

    # evaluate
    class_correct = np.zeros(10)
    class_total = np.zeros(10)
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)

            for i in range(labels.size(0)):  # the last batch may be smaller than batch_size
                label = labels[i]
                class_total[label] += 1
                if predicted[i] == label:
                    class_correct[label] += 1

    print("\nEvery class precious: \n ",
          ' '.join("%5s : %2d %%\n" % (classes[i], 100 * class_correct[i]/class_total[i]) for i in range(len(classes))))
    print("\n%d images in all, Total precious: %2d %%"
          % (np.sum(class_total), 100 * np.sum(class_correct) / np.sum(class_total)))
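imshow() is used at the top of this function but not defined in the excerpt; a common torchvision-tutorial-style definition (a sketch; whether the original un-normalizes first depends on its transforms):

# Hypothetical helper; assumes the image grid is already in displayable range.
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()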
Example #10
    def get_soft_label(self, img):
        img = Image.fromarray(img.numpy(), mode='L')
        if self.transform is not None:
            img = self.transform(img)
        return self.init_target_transform(img)


leNetModel = LeNet(args)

if args.cuda:
    leNetModel.cuda()

soft_labels = []

if importMode:
    leNetModel.load_state_dict(
        torch.load(os.path.join('./result', 'lenet5_best.json')))
else:
    for epoch in range(1, args.epochs + 1):
        leNetModel.train_(train_loader, epoch)


def get_soft_label(img):
    return leNetModel.forward_with_temperature(img.view(
        1, *(img.size()))).view(-1)
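forward_with_temperature() is a method of the project's LeNet and is not shown; a sketch of the temperature-scaled softmax it presumably computes (written here as a free function; the temperature value is a placeholder):

# Hypothetical stand-in for LeNet.forward_with_temperature; temperature=2.0 is a placeholder.
import torch.nn.functional as F

def forward_with_temperature(model, x, temperature=2.0):
    logits = model(x)                              # raw class scores
    return F.softmax(logits / temperature, dim=1)  # softened probabilities for distillation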


soft_train_loader = torch.utils.data.DataLoader(SoftMNIST(
    './data',
    import_targets=True,
    train=True,
    transform=transforms.Compose(
net.initialize_weights()

# ============================ step 3/5: loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function

# ============================ step 4/5: optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6,
                                            gamma=0.1)  # learning-rate decay schedule

# ============================ step 5+/5: resume from checkpoint ============================

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch

# ============================ step 5/5: training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
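    # --- sketch of the truncated loop body; train_loader is an assumed name, not from the excerpt ---
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
        loss_mean += loss.item()
        train_curve.append(loss.item())

    scheduler.step()  # advance the StepLR schedule once per epoch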
Example #12
from sklearn.metrics import confusion_matrix  # confusion-matrix helper
import matplotlib.pyplot as plt  # plotting library
import numpy as np
import torch
import torch.nn as nn
from lenet import LeNet
from dataset import Dataset
batch_size = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#导入模型类、数据集类
net = LeNet()
net.load_state_dict(torch.load("./model/model.pkl", map_location=device))
net = net.to(device)  # move the model to the same device the input batches use


# #实例化模型
# my_model = MyModel.to(device)
# #加载模型
# my_model.load_state_dict(torch.load("./result/model.pth"))
def predict():
    net.eval()
    y_true_list = []
    y_pred_list = []
    predict_dataset = Dataset(train=False, batch_size=batch_size)
    for i, (data, labels) in enumerate(predict_dataset):
        with torch.no_grad():
            data = data.to(device)
            y_true = labels.numpy()
            outputs = net(data)
            y_pred = outputs.max(-1)[-1]
            y_pred = y_pred.cpu().data.numpy()
            for i in range(batch_size):
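                # sketch of the truncated remainder: gather per-sample labels and
                # predictions for the confusion matrix (assumed from the
                # y_true_list / y_pred_list accumulators above)
                y_true_list.append(y_true[i])
                y_pred_list.append(y_pred[i])

    # hypothetical wrap-up using the sklearn / matplotlib imports at the top of this example
    cm = confusion_matrix(y_true_list, y_pred_list)
    plt.matshow(cm, cmap=plt.cm.Blues)
    plt.colorbar()
    plt.xlabel("Predicted label")
    plt.ylabel("True label")
    plt.savefig("./confusion_matrix.png")
    return cm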