示例#1
0
def main():
    """Train or evaluate SegNet on the MICCAI image dataset.

    Parses command-line options with the module-level ``parser``, optionally
    restores model weights from ``args.resume``, builds train/val data loaders
    from the hard-coded ``data_dir``, and then either runs a single validation
    pass (when ``args.evaluate`` is set) or trains for epochs
    ``args.start_epoch .. args.epochs - 1``, saving a checkpoint after each.

    Side effects:
        Rebinds the module-level globals ``args`` and ``best_prec1``; creates
        ``args.save_dir`` if it is missing; writes one checkpoint per epoch.

    Note:
        ``best_prec1`` is treated as lower-is-better (see ``is_best`` below),
        so ``validate`` presumably returns a loss rather than an accuracy —
        confirm against ``validate``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Check if the save directory exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = SegNet(3, 3)

    #model.features = torch.nn.DataParallel(model.features)
    if use_gpu:
        model.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if use_gpu:
                checkpoint = torch.load(args.resume)
            else:
                # Keep every tensor on the CPU so a checkpoint written by a
                # GPU run can still be loaded on a CPU-only machine.
                checkpoint = torch.load(
                    args.resume, map_location=lambda storage, loc: storage)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            # Report the checkpoint file that was actually loaded.
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # data_transforms = {
    #     'train': transforms.Compose([
    #         transforms.Scale(256),
    #         transforms.RandomSizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    #     ]),
    #     'val': transforms.Compose([
    #         transforms.Scale(256),
    #         transforms.CenterCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    #     ]),
    # }

    # Both splits are only resized to 224x224 and converted to tensors.
    data_transforms = {
        'train':
        transforms.Compose([
            transforms.Scale((224, 224)),
            transforms.ToTensor(),
        ]),
        'val':
        transforms.Compose([
            transforms.Scale((224, 224)),
            transforms.ToTensor(),
        ]),
    }

    data_dir = '/media/salman/DATA/NUST/MS RIME/Thesis/MICCAI Dataset/miccai_all_images'

    image_datasets = {
        x: miccaiDataset(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'val']
    }

    dataloaders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.workers)
        for x in ['train', 'val']
    }
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

    # Define loss function (criterion) and optimizer.
    # Only move the criterion to the GPU when the model is on the GPU too;
    # an unconditional .cuda() breaks CPU-only runs.
    criterion = nn.MSELoss()
    if use_gpu:
        criterion = criterion.cuda()

    if args.half:
        model.half()
        criterion.half()

    #optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                            momentum=args.momentum,
    #                            weight_decay=args.weight_decay)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if args.evaluate:
        validate(dataloaders['val'], model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        #adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        train(dataloaders['train'], model, criterion, optimizer, epoch)

        # Evaluate on validation set
        prec1 = validate(dataloaders['val'], model, criterion)
        prec1 = prec1.cpu().data.numpy()

        # Remember best (lowest) prec1 and save checkpoint
        print(prec1)
        print(best_prec1)
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                #'optimizer': optimizer.state_dict(),
            },
            is_best,
            filename=os.path.join(args.save_dir,
                                  'checkpoint_{}.tar'.format(epoch)))
示例#2
0
def handler(context):
    """Train SegNet on the train/val datasets referenced by ``context``.

    Args:
        context: object whose ``datasets`` attribute maps ``'train'`` and
            ``'val'`` to dataset ids accepted by ``SegmentationDatasetFromAPI``.

    For each of ``epochs`` epochs the network is trained on the training
    split and evaluated on the validation split. Metrics are printed, passed
    to ``statistics`` and written to ``writer``. The final weights are saved
    to ``ABEJA_TRAINING_RESULT_DIR/model.pth``.
    """
    # Dataset
    dataset_alias = context.datasets
    train_dataset_id = dataset_alias['train']
    val_dataset_id = dataset_alias['val']

    trainset = SegmentationDatasetFromAPI(train_dataset_id,
                                          transform=SegNetAugmentation(MEANS))
    valset = SegmentationDatasetFromAPI(val_dataset_id,
                                        transform=SegNetAugmentation(
                                            MEANS, False))
    # Per-class loss weights come from the training split, read through the
    # same SegNetAugmentation(MEANS, False) transform used for validation.
    class_weight = calc_weight(
        SegmentationDatasetFromAPI(train_dataset_id,
                                   transform=SegNetAugmentation(MEANS, False)))
    class_weight = class_weight.to(device)

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=BATCHSIZE,
                                              shuffle=True,
                                              num_workers=0)
    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=BATCHSIZE,
                                            shuffle=False,
                                            num_workers=0)

    # Model
    net = SegNet(3, n_class=len(camvid_label_names))
    net = net.to(device)

    # Optimizer
    #criterion = PixelwiseSoftmaxClassifier(weight=class_weight)
    # Pixels labelled -1 are excluded from the loss.
    criterion = torch.nn.CrossEntropyLoss(weight=class_weight, ignore_index=-1)
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    # Learning rate is multiplied by 0.1 at epochs 150 and 250.
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[150, 250],
                                         gamma=0.1)

    statistics = Statistics(epochs)

    for epoch in range(epochs):
        train_loss, train_acc = train(net, optimizer, trainloader, criterion,
                                      epoch)
        # Step the LR schedule once per epoch, *after* the optimizer updates.
        # Since PyTorch 1.1, stepping it before any optimizer.step() skips the
        # initial learning rate and moves every milestone one epoch early.
        scheduler.step()
        test_loss, test_acc = test(net, valloader, criterion, epoch)

        # Reporting
        print(
            '[{:d}] main/loss: {:.3f} main/acc: {:.3f}, main/validation/loss: {:.3f}, main/validation/acc: {:.3f}'
            .format(epoch + 1, train_loss, train_acc, test_loss, test_acc))

        statistics(epoch + 1, train_loss, train_acc, test_loss, test_acc)
        writer.add_scalar('main/loss', train_loss, epoch + 1)
        writer.add_scalar('main/acc', train_acc, epoch + 1)
        writer.add_scalar('main/validation/loss', test_loss, epoch + 1)
        writer.add_scalar('main/validation/acc', test_acc, epoch + 1)

    torch.save(net.state_dict(),
               os.path.join(ABEJA_TRAINING_RESULT_DIR, 'model.pth'))
示例#3
0
def _vgg16_bn_to_segnet_names():
    """Return the map from torchvision ``vgg16_bn`` names to SegNet encoder names.

    ``vgg16_bn`` stores its encoder as one flat ``features`` sequence, while
    SegNet keeps the same layers per block in ``vgg16_block1`` ..
    ``vgg16_block5``, preserving each layer's offset inside its block.
    Within a block, a conv layer (``weight``, ``bias``) sits at offset
    ``3 * i`` and its batch-norm layer (``weight``, ``bias``,
    ``running_mean``, ``running_var``) at ``3 * i + 1``. The skipped offsets
    hold parameter-free layers.
    """
    # (first index in ``features``, number of conv/batch-norm pairs) per block.
    blocks = ((0, 2), (7, 2), (14, 3), (24, 3), (34, 3))
    conv_params = ('weight', 'bias')
    bn_params = ('weight', 'bias', 'running_mean', 'running_var')
    mapping = {}
    for block_no, (start, n_pairs) in enumerate(blocks, 1):
        for i in range(n_pairs):
            conv_offset = 3 * i
            for offset, params in ((conv_offset, conv_params),
                                   (conv_offset + 1, bn_params)):
                for param in params:
                    src = 'features.{}.{}'.format(start + offset, param)
                    dst = 'vgg16_block{}.{}.{}'.format(block_no, offset, param)
                    mapping[src] = dst
    return mapping


def transfer_pretrained_weighted(pretrained_path='vgg16_bn-6c64b313.pth',
                                 output_path='transfer-vgg16-for11classes.pth'):
    """Copy pretrained VGG16-BN encoder weights into a fresh SegNet and save it.

    Args:
        pretrained_path: path to the torchvision ``vgg16_bn`` state dict. You
            have to download this file beforehand.
        output_path: where the resulting SegNet ``state_dict`` is saved.

    Returns:
        The SegNet model with the transferred encoder weights loaded.
    """
    model = SegNet()
    corresp_name = _vgg16_bn_to_segnet_names()
    s_dict = model.state_dict()
    # Map storages to the CPU so this works on machines without a GPU; the
    # values are copied into the model's own tensors by load_state_dict.
    pretrained_dict = torch.load(pretrained_path,
                                 map_location=lambda storage, loc: storage)
    for name, tensor in pretrained_dict.items():
        if name in corresp_name:
            s_dict[corresp_name[name]] = tensor
    # Make a partial transfer visible (e.g. a wrong checkpoint file) instead
    # of silently keeping random initialisation for some layers.
    missing = [name for name in corresp_name if name not in pretrained_dict]
    if missing:
        print("=> warning: {} of {} expected VGG16-BN tensors not found in '{}'"
              .format(len(missing), len(corresp_name), pretrained_path))
    model.load_state_dict(s_dict)
    torch.save(model.state_dict(), output_path)
    return model