# ===== Example 1 =====
from segnet import SegNet
import time
import matplotlib.pyplot as plt
import argparse

parse = argparse.ArgumentParser()
parse.add_argument('--mode', choices=['train', 'val', 'test'])
parse.add_argument('--batch_size', '-b', type=int, default=16)
# BUG FIX: argparse's type=bool converts any non-empty string (including
# "False") to True, so --resume could never be turned off from the CLI.
# A store_true flag keeps the default False and behaves as intended.
parse.add_argument('--resume', action='store_true', default=False)
parse.add_argument('--epochs', type=int, default=33)
args = parse.parse_args()

# Use GPU 0 when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): torch/nn are not imported in this chunk — presumably
# imported elsewhere in the file; verify.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegNet()
if args.mode == 'train':
    # Training starts from the VGG16-transferred encoder weights.
    model.load_state_dict(torch.load('transfer-vgg16-for11classes.pth'))
else:
    # Validation/testing loads the fully trained SegNet weights.
    model.load_state_dict(torch.load('segnet_weight_11classes.pth'))

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
model.to(device)


def train(epochs):
    model.train()
    train_dataset = CamVidDataset(phase='train')
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
# ===== Example 2 =====
def main():
    """Train or evaluate SegNet (3-channel in/out) on the MICCAI dataset.

    Parses CLI options from the module-level ``parser``, optionally resumes
    from a checkpoint, then runs a train/validate loop that saves a
    checkpoint every epoch while tracking the best (lowest) validation
    metric in the module-level ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Make sure the checkpoint/save directory exists before writing to it.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = SegNet(3, 3)

    if use_gpu:
        model.cuda()

    # Optionally resume model weights and bookkeeping from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUG FIX: the success message previously interpolated
            # args.evaluate instead of the checkpoint path actually loaded.
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Let cuDNN benchmark and cache the fastest conv algorithms, since all
    # inputs are resized to a fixed 224x224 shape below.
    cudnn.benchmark = True

    # NOTE(review): transforms.Scale is deprecated in newer torchvision
    # (renamed to Resize); kept as-is to match the file's torchvision usage.
    data_transforms = {
        'train':
        transforms.Compose([
            transforms.Scale((224, 224)),
            transforms.ToTensor(),
        ]),
        'val':
        transforms.Compose([
            transforms.Scale((224, 224)),
            transforms.ToTensor(),
        ]),
    }

    data_dir = '/media/salman/DATA/NUST/MS RIME/Thesis/MICCAI Dataset/miccai_all_images'

    image_datasets = {
        x: miccaiDataset(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'val']
    }

    dataloaders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.workers)
        for x in ['train', 'val']
    }

    # Pixel-wise regression objective (image-to-image reconstruction).
    criterion = nn.MSELoss().cuda()

    # Optional half-precision mode: model and criterion must share dtype.
    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Evaluation-only mode: run a single validation pass and exit.
    if args.evaluate:
        validate(dataloaders['val'], model, criterion)
        return

    # NOTE(review): when not resuming, best_prec1 must already be
    # initialized at module level, otherwise min() below raises NameError
    # on the first epoch — verify against the surrounding file.
    for epoch in range(args.start_epoch, args.epochs):
        # Train for one epoch.
        train(dataloaders['train'], model, criterion, optimizer, epoch)

        # Evaluate on the validation set; lower is better for this metric.
        prec1 = validate(dataloaders['val'], model, criterion)
        prec1 = prec1.cpu().data.numpy()

        # Track the best score so far and persist this epoch's checkpoint.
        print(prec1)
        print(best_prec1)
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir,
                                  'checkpoint_{}.tar'.format(epoch)))
# ===== Example 3 =====
def load_model():
    """Build a SegNet on the GPU, load its trained weights, and return it."""
    weight_path = os.path.join(model_dir, 'segnet_model.pth')
    net = SegNet(input_channel, output_channel).cuda()
    # Weights were saved as a state dict, not a pickled module.
    net.load_state_dict(torch.load(weight_path))
    return net
# ===== Example 4 =====
def transfer_pretrained_weighted(pretrained_path='vgg16_bn-6c64b313.pth',
                                 out_path='transfer-vgg16-for11classes.pth'):
    """Copy VGG16-BN encoder weights into a fresh SegNet and save it.

    Parameters
    ----------
    pretrained_path : str
        Path to the torchvision ``vgg16_bn`` weight file (must be
        downloaded beforehand). Defaults to the original hard-coded name.
    out_path : str
        Where to save the resulting SegNet state dict. Defaults to the
        original hard-coded name, so existing callers are unaffected.
    """
    model = SegNet()

    # Map torchvision's 'features.<i>.<param>' keys onto SegNet's
    # 'vgg16_block<b>.<j>.<param>' keys. The layout is regular: each
    # encoder block starts at a fixed offset in vgg16_bn.features, and
    # within a block the k-th conv sits at local offset 3k with its
    # batch-norm immediately after at 3k+1 (same local offsets on both
    # sides of the mapping).
    corresp_name = {}
    # (block number, offset of block start in vgg16_bn.features, conv count)
    block_layout = [(1, 0, 2), (2, 7, 2), (3, 14, 3), (4, 24, 3), (5, 34, 3)]
    conv_params = ('weight', 'bias')
    bn_params = ('weight', 'bias', 'running_mean', 'running_var')
    for block_no, start, conv_count in block_layout:
        for k in range(conv_count):
            conv_local = 3 * k          # conv position within the block
            bn_local = conv_local + 1   # batch-norm directly follows it
            for param in conv_params:
                corresp_name['features.%d.%s' % (start + conv_local, param)] = \
                    'vgg16_block%d.%d.%s' % (block_no, conv_local, param)
            for param in bn_params:
                corresp_name['features.%d.%s' % (start + bn_local, param)] = \
                    'vgg16_block%d.%d.%s' % (block_no, bn_local, param)

    s_dict = model.state_dict()
    # You have to download the pretrained vgg16_bn weight file first.
    pretrained_dict = torch.load(pretrained_path)
    # Copy only the encoder parameters we have a mapping for; decoder
    # weights keep their fresh initialization.
    for name in pretrained_dict:
        if name not in corresp_name:
            continue
        s_dict[corresp_name[name]] = pretrained_dict[name]
    model.load_state_dict(s_dict)
    torch.save(model.state_dict(), out_path)