Пример #1
0
        correct += (pred == labels).sum().item()

    print('Test set: Average loss: {:.4f}, Accuracy:{}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset), 100. * correct/len(test_loader.dataset)))

    outfile.write('Test set: Average loss: {:.4f}, Accuracy:{}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset), 100. * correct/len(test_loader.dataset)))
    outfile.write('\n')
    return 100. * correct/len(test_loader.dataset)

# Augmentation pipeline: mirror each sample left-right with probability 0.5.
# NOTE(review): no tensor conversion is applied here; presumably the dataset
# class converts images itself -- confirm against RetinopathyLoader.
transform = tr.Compose([tr.RandomHorizontalFlip(p=0.5)])

# Datasets: only the training split receives the augmentation pipeline.
train_dataset = dataloader.RetinopathyLoader('data/', 'train', transform=transform)
test_dataset = dataloader.RetinopathyLoader('data/', 'test')

# Reshuffle the training samples at every epoch so that mini-batches are not
# always drawn in the same dataset order. The test loader stays unshuffled so
# evaluation output keeps a stable order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)

#import sys
if __name__ == '__main__':
    # Script entry point: show the network, train it, then store its weights.
    print(resnet_model)

    # train() returns a pair, bound here as train_acc / test_acc.
    train_acc, test_acc = train(resnet_model, optimizer)

    # Only the state dict is stored, wrapped under the 'state_dict' key.
    saved_state = {'state_dict': resnet_model.state_dict()}
    torch.save(saved_state, 'resnet50_with_pretrain_5.tar')
Пример #2
0
                    help="weight_decay(L2 penalty)")
# Training length.
parser.add_argument('--epochs', type=int, default=10,
                    help="number of total epochs to run")
# '--learning-rate' is a long alias; the parsed value is stored as args.lr.
parser.add_argument('--lr', '--learning-rate', type=float, default=1e-3,
                    help="initial learning rate")
# NOTE(review): --checkpoint and --resume carry identical help text, so
# `--help` does not tell them apart -- confirm which one is the file to write
# and which one is the file to load.
parser.add_argument('--checkpoint', type=str,
                    help="name of checkpoint file")
parser.add_argument('--resume', type=str,
                    help="name of checkpoint file")

# Parse sys.argv once at import time; later code reads options from `args`.
args = parser.parse_args()

# Data pipeline: build both dataset splits, then wrap them in loaders.
train_dataset = dataloader.RetinopathyLoader("./data/", 'train')
test_dataset = dataloader.RetinopathyLoader("./data/", 'test')

# Training batches are shuffled; the test loader keeps DataLoader's default
# (unshuffled) order. Both use the batch size given on the command line.
train_dataloader = DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size)

# Model
# `model` stays None unless args.model matches an architecture handled below.
model = None
if args.model == "ResNet18":
    print('Use ResNet18')
    #  model = resnet.resnet18(num_classes=5)
    # torchvision's ResNet-18, constructed with default arguments.
    model = torchvision.models.resnet18()
    # Remember the width of the stock classifier input before replacing it.
    num_ftrs = model.fc.in_features
    # Adaptive pooling to a 1x1 output.
    # NOTE(review): presumably so input resolution does not have to match the
    # stock network's expected size -- confirm.
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    # New final layer mapping the pooled features to `num_classes` outputs.
    model.fc = nn.Linear(num_ftrs, num_classes)
Пример #3
0
    print('Test set: Average loss: {:.4f}, Accuracy:{}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    outfile.write(
        'Test set: Average loss: {:.4f}, Accuracy:{}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
    cm.plot_confusion_matrix(np.array(preds),
                             np.array(test_dataset.label),
                             np.array(['0', '1', '2', '3', '4']),
                             normalize=True)
    return 100. * correct / len(test_loader.dataset)


# Both splits are built without passing any transform argument.
train_dataset = dataloader.RetinopathyLoader('data/', 'train')
test_dataset = dataloader.RetinopathyLoader('data/', 'test')

# Neither loader shuffles, so batches follow the dataset's own ordering.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Network instance moved onto the configured device.
load_model = models.ResNet().to(device)
#import sys
if __name__ == '__main__':

    # start training