def main():

    model = GoogLeNet(num_classes=num_clazz,
                      aux_logits=True,
                      init_weights=True).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    criterion = nn.CrossEntropyLoss().to(device)

    validation_acc = []
    best_acc, best_epoch = 0, 0
    global_step = 0

    for epoch in range(epochs):
        model.train()  # switch to training mode
        total_batch_loss = 0
        print('Epoch {}/{} - training'.format(epoch + 1, epochs))
        for step, data in enumerate(train_loader):
            x, y = data
            x, y = x.to(device), y.to(device)
            logits, aux_logits2, aux_logits1 = model(x)
            loss0 = criterion(logits, y)
            loss1 = criterion(aux_logits1, y)
            loss2 = criterion(aux_logits2, y)
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3  # auxiliary losses weighted by 0.3, as in the GoogLeNet paper
            total_batch_loss += loss.item()
            # zero the gradients
            optimizer.zero_grad()
            # compute the gradients
            loss.backward()
            # update the parameters
            optimizer.step()

            if step % 200 == 0:
                print('Step {}/{}\tloss: {:.4f}'.format(step, len(train_loader),
                                                        loss.item()))

        # switch to evaluation mode
        model.eval()
        val_acc = evalute(model, val_loader)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), 'best.mdl')

        scheduler.step()  # adjust the learning rate
        print("epoch: ", epoch, "epoch_loss: ", total_batch_loss, "epoch_acc:",
              val_acc)

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)
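

# The evalute() helper called above is not defined in this snippet; a minimal
# sketch, assuming it simply measures top-1 accuracy over a DataLoader (the
# name and signature follow the calls above):
def evalute(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)  # in eval mode only the main classifier output is returned
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()
            total += x.size(0)
    return correct / total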
Example 2
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(
        img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(
        json_path)

    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)
    # * aux_logits=False at construction time: the auxiliary classifiers do not need to be built for inference

    # load model weights
    weights_path = "./googleNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(
        weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(
        weights_path, map_location=device),
                                                          strict=False)
    # * strict defaults to True, which requires an exact structural match between
    #   the current model and the weights being loaded.
    # * strict=False here because the checkpoint was saved with the two auxiliary
    #   classifiers, which this inference-only model does not contain.
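    # With such a checkpoint, unexpected_keys should list only the parameters of
    # the two auxiliary classifiers; printing it is a quick sanity check (not part
    # of the original example):
    # print("unexpected keys:", unexpected_keys)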

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(
        class_indict[str(predict_cla)], predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()
def load_model():
    model = GoogLeNet()
    model.load_state_dict(torch.load(PRETRAINED_MODEL))
    return model
Example 4
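# This fragment assumes the same preprocessing pipeline as the full examples
# above; a minimal definition so the snippet runs on its own:
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
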
# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = GoogLeNet(num_classes=5, aux_logits=False)
# load model weights
model_weight_path = "./googleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(predict.numpy()[predict_cla])
print(class_indict[str(predict_cla)])
plt.show()
Example 5
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "./a.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)

    # load model weights
    weights_path = "./weights/googlenet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)

    model.eval()
    print('=================================')
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        'googlenet.onnx',
        dynamic_axes={'image': {0: 'B'}, 'outputs': {0: 'B'}},
        input_names=['image'],
        output_names=['outputs'],
        opset_version=12
    )
    print('=================================')

    print('---------------------------------')
    traced_script_module = torch.jit.trace(model, dummy_input)
    traced_script_module.save("googlenet.pt")
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()
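

# A quick way to exercise the exported artifacts; a minimal sketch, assuming
# onnxruntime is installed ('image' and 'outputs' are the names used in the
# export call above):
def check_exports(img):
    import onnxruntime as ort
    sess = ort.InferenceSession("googlenet.onnx", providers=["CPUExecutionProvider"])
    onnx_out = sess.run(["outputs"], {"image": img.numpy()})[0]
    traced = torch.jit.load("googlenet.pt", map_location="cpu")
    traced.eval()
    with torch.no_grad():
        ts_out = traced(img)
    # both backends should agree with the eager model up to small numerical error
    print("max |onnx - torchscript| difference:", abs(onnx_out - ts_out.numpy()).max())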