Example #1
def load_model(opt, checkpoint_dir):
    checkpoint_list = glob.glob(os.path.join(checkpoint_dir, "*.pth"))
    checkpoint_list.sort()

    if opt.resume_best:
        loss_list = list(
            map(lambda x: float(os.path.basename(x).split('_')[4][:-4]),
                checkpoint_list))
        best_loss_idx = loss_list.index(min(loss_list))
        checkpoint_pth = checkpoint_list[best_loss_idx]
    else:
        checkpoint_pth = checkpoint_list[-1]

    net = VGG(opt)

    if os.path.isfile(checkpoint_pth):
        print("=> loading checkpoint '{}'".format(checkpoint_pth))
        checkpoint = torch.load(checkpoint_pth)

        n_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['net'].state_dict())
        print("=> loaded checkpoint '{}'(epoch {})".format(
            checkpoint_pth, n_epoch))
    else:
        print("=> no checkpoint found at {}".format(checkpoint_pth))
        n_epoch = 0

    return n_epoch + 1, net
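
The `resume_best` branch above assumes that each checkpoint filename encodes the validation loss as its fifth underscore-separated field (index 4). The actual naming scheme is not shown in this excerpt, so the pattern below is only an illustration of a layout that would satisfy the parsing:

import os

# Hypothetical filename layout matching the parsing in load_model:
#   <prefix>_<net>_<tag>_<epoch>_<val_loss>.pth
fname = "ckpt_vgg_epoch_010_0.2345.pth"                    # illustration only
loss = float(os.path.basename(fname).split('_')[4][:-4])   # -> 0.2345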
Example #2
def test_whole(config):

    # Dataset
    multiclass_train = Multiclass_Dataset('train')
    multiclass_val = Multiclass_Dataset('val')
    multiclass_test = Multiclass_Dataset('test')

    # Dataloader
    multiclass_train_loader = DataLoader(multiclass_train,
                                         batch_size=2 * config.batch_size,
                                         shuffle=True)
    multiclass_test_loader = DataLoader(multiclass_test,
                                        batch_size=2 * config.batch_size,
                                        shuffle=True)
    multiclass_val_loader = DataLoader(multiclass_val,
                                       batch_size=2 * config.batch_size,
                                       shuffle=True)

    # for epoch in range(config.epochs):
    #     # Load trained model parameters
    #     model_infection = VGG()
    #     model_infection_path = './checkpoints/' + str(epoch) + '_params_infection.pth'
    #     model_infection.load_state_dict(copy.deepcopy(torch.load(model_infection_path, config.device)))
    #     model_infection.to(config.device)
    #     model_covid = VGG()
    #     model_covid_path = './checkpoints/' + str(epoch) + '_params_covid.pth'
    #     model_covid.load_state_dict(copy.deepcopy(torch.load(model_covid_path, config.device)))
    #     model_covid.to(config.device)

    #     train_accuracy_whole = test(multiclass_train_loader, model_infection, model_covid)
    #     test_accuracy_whole = test(multiclass_test_loader, model_infection, model_covid)
    #     val_accuracy_whole = test(multiclass_val_loader, model_infection, model_covid)
    #     print(f"epoch {epoch}: train_acc_overall: {train_accuracy_whole}, val_acc_overall: {val_accuracy_whole}, test_acc_overall: {test_accuracy_whole}")

    # Evaluate the performance using the combination of the best models
    model_infection = VGG()
    model_infection_path = './checkpoints/' + 'best_params_infection.pth'
    model_infection.load_state_dict(
        copy.deepcopy(torch.load(model_infection_path, config.device)))
    model_infection.to(config.device)
    model_covid = VGG()
    model_covid_path = './checkpoints/' + 'best_params_covid.pth'
    model_covid.load_state_dict(
        copy.deepcopy(torch.load(model_covid_path, config.device)))
    model_covid.to(config.device)

    train_accuracy_whole = test(multiclass_train_loader, model_infection,
                                model_covid)
    test_accuracy_whole = test(multiclass_test_loader, model_infection,
                               model_covid)
    val_accuracy_whole = test(multiclass_val_loader, model_infection,
                              model_covid)
    print()
    print(
        f"Combination of best models: train_acc_overall: {train_accuracy_whole}, val_acc_overall: {val_accuracy_whole}, test_acc_overall: {test_accuracy_whole}"
    )
Example #3
def train():
    train_dataloader, val_dataloader = loadData()
    pretrained_params = torch.load('VGG_pretrained.pth')
    model = VGG()
    # strict=False loads only the pretrained parameters whose names match the new model; parameters that do not match (or are missing) are discarded.
    model.load_state_dict(pretrained_params.state_dict(), strict=False)

    if torch.cuda.is_available():
        model.cuda()

    # Freeze the XXlayer parameters when fine-tuning


#    for p in model.XXlayers.parameters():
#        p.requires_grad = False
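#    If layers are frozen this way, it is also common to pass only the trainable
#    parameters to the optimizer (an illustrative addition, not in the original):
#    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
#                           lr=learning_rate)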

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss()
    best_acc = 0

    for epoch in range(epochs):
        epoch_loss = 0
        steps = 0
        for i, data in enumerate(train_dataloader):
            inputs, labels = data
            if torch.cuda.is_available():
                inputs, labels = inputs.cuda(), labels.cuda()
            inputs, labels = Variable(inputs), Variable(labels)
            model.train()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()  # loss.data[0] fails on 0-dim tensors in modern PyTorch
            steps += 1
        print('epoch:%d loss:%.3f' % (epoch + 1, epoch_loss / steps))
        if epoch % 5 == 0:
            val_acc = evaluate(model, val_dataloader)
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model, 'best_VGG.pkl')
                torch.save(model.state_dict(), 'best_VGG_params.pkl')
            print('val acc: {}'.format(val_acc))

    print('Finished Training')
    torch.save(model, 'VGG.pkl')
    torch.save(model.state_dict(), 'VGG_params.pkl')
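
The `evaluate` helper called every fifth epoch is not part of this excerpt. A minimal sketch of a top-1 accuracy evaluation, where the name and signature come from the call site and the body is assumed, could look like this:

def evaluate(model, dataloader):
    """Assumed helper: top-1 accuracy of model over dataloader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            if torch.cuda.is_available():
                inputs, labels = inputs.cuda(), labels.cuda()
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total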
Example #4
    if visual_heatmap:
        plt.matshow(heatmap)
        plt.show()
 
    img = cv2.imread(img_path)  # load the original image with cv2
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))  # resize the heatmap to match the original image
    heatmap = np.uint8(255 * heatmap)  # scale the heatmap to 0-255 and convert to uint8
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # apply the JET colormap to the heatmap
    superimposed_img = heatmap * 0.4 + img  # 0.4 is the heatmap intensity factor
    cv2.imwrite(save_path, superimposed_img)  # write the superimposed image to disk

    
if __name__ == '__main__':
    # parser
    # setting
    torch.manual_seed(11)
    device = torch.device(0)
    # data loader
    data_path = './data'

    lr = 0.001
    model = VGG().to(device)
    ckpt = torch.load('./model/vgg.pt')
    model.load_state_dict(ckpt['model_state_dict'])
    draw_CAM(model=model, img_path='./data/COVID/Covid (1007).png', save_path='heat.png', transform=None, visual_heatmap=True)
    #model.eval()
    #optimizer = optim.Adam(model.parameters(), lr=lr)
    #metric = nn.CrossEntropyLoss().to(device)
    #acc = test(model, device, test_loader, metric)
    #print(acc)
Example #5
args = parser.parse_args()

args.eval = True
args.train = False

my_lr = 0.002
net = VGG()
#net = ResNet()
start_epoch = 1
start_step = 1

if args.resume or args.demo or args.eval:
    checkpoint = torch.load("../../models/base/vgg_base_model.pth")
    start_epoch = checkpoint['epoch']
    net.load_state_dict(checkpoint['model'])
    my_lr = checkpoint['lr']
    start_step = checkpoint['step']
    lost = checkpoint['lost']
    error = checkpoint['error']
    print(lost)

print(net)

net = net.to(device)

criterion = nn.CrossEntropyLoss()


def get_error(scores, labels):
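    # The body of get_error is cut off in this excerpt.  A minimal sketch of a
    # top-1 error computation (an assumption, not the original implementation):
    predictions = scores.argmax(dim=1)
    wrong = (predictions != labels).sum().item()
    return wrong / labels.size(0)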
Example #6
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Model
    print('======> Building model...')
    net = VGG('VGG19')
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    if config.resume:
        # Load checkpoint.
        print('====> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.t7')
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    LossFunc = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config.lr, momentum=0.9, weight_decay=5e-4)

    for epoch in range(start_epoch, start_epoch+50):
        train(epoch)
        test(epoch)
        
        file = open('Loss_seq2.txt', 'wb')
        pickle.dump(loss_seq, file)
        file.close()

        file = open('Acc_seq2.txt', 'wb')
Example #7
parser.add_argument('--rootDir',
                    type=str,
                    default='Data/',
                    help='Directory path to root of the training data')
parser.add_argument('--checkpoint',
                    type=str,
                    default='checkpoint/best_vgg.pth')

opt = parser.parse_args()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

idx_to_classes = {
    index: cavity
    for index, cavity in enumerate(os.walk(opt.rootDir).__next__()[1])
}
transformation = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor()])

for file in os.listdir('test'):
    image = transformation(
        Image.open(os.path.join(opt.imagePath,
                                file)).convert("RGB")).to(device).unsqueeze(0)

    model = CNN(num_classes=len(idx_to_classes)).to(device)
    model.load_state_dict(
        torch.load(opt.checkpoint, map_location=lambda storage, loc: storage))

    _, predict = torch.max(model(image).data.cpu(), 1)
    print("It's a {}!".format(idx_to_classes[predict.item()]), file)
Example #8
import torch
import torchvision.transforms as transforms
import PIL.Image as Image
import torch.optim as optim
import torch.nn.functional as F
from tensorboardX import SummaryWriter

from data_loader import get_data_loader
from model import VGG
from focal_loss import FocalLoss

if __name__ == "__main__":
    r = 16
    thresh = 0.9
    # model
    net = VGG(1).cuda()
    net.load_state_dict(torch.load("./modules/vgg-160-1.433.pth.tar")['state_dict'])
    net.eval()

    # test image
    image = Image.open("./predictions/test5.jpg")

    # to patch tensor
    trans = transforms.Compose([transforms.Pad(r, padding_mode='symmetric'), transforms.Grayscale(), transforms.ToTensor()])
    image = torch.squeeze(trans(image))
    w, h = image.shape  # note: for a 2-D tensor this is actually (height, width)
    print(image.shape)

    patches = [image[i - r:i + r, j - r:j + r] for i in range(r, w - r) for j in range(r, h - r)]
    a = patches[23]
    print(len(patches))
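
    # Assumed continuation (not part of the source snippet): run the patches
    # through the network in batches.  What the outputs represent, and whether
    # the 0.9 threshold (thresh) applies to them directly, depends on the
    # unshown VGG(1) definition, so this is only a sketch.
    with torch.no_grad():
        outputs = []
        for k in range(0, len(patches), 256):
            batch = torch.stack(patches[k:k + 256]).unsqueeze(1).cuda()  # (B, 1, 2r, 2r)
            outputs.append(net(batch).cpu())
        outputs = torch.cat(outputs)
    print(outputs.shape)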