def load_trained_model(model_name, weights_path, num_classes):
    """Instantiate the requested architecture and restore saved weights.

    Args:
        model_name: architecture key; only 'segnet' is supported.
        weights_path: path to a checkpoint holding a 'state_dict' entry.
        num_classes: number of output classes for the network head.

    Returns:
        The model with its trained weights loaded (on CPU).

    Raises:
        AssertionError: if `model_name` is not a supported architecture.
    """
    # Guard clause: reject unknown architectures up front.
    if model_name != 'segnet':
        raise AssertionError('model not available')

    model = SegNet(num_classes)
    # map_location='cpu' lets checkpoints saved on GPU load on any machine.
    checkpoint = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return model
# ----- Example 2 -----
    # NOTE(review): fragment — this chunk starts mid-function; `model1`,
    # `modelName1`, and `device` are bound outside this view, so the notes
    # below are conservative.
    # Second model: DeepLabV3/ResNet-101 with a single output channel
    # (binary segmentation), built without pretrained weights.
    model2 = torchvision.models.segmentation.deeplabv3_resnet101(
        pretrained=False, num_classes=1)
    model2 = model2.to(device)
    # Checkpoint subdirectories are keyed by each model's class name.
    modelName2 = model2.__class__.__name__
    # Third model: UNet with 1 input channel and 1 output channel.
    model3 = UNet(1, 1).to(device)
    modelName3 = model3.__class__.__name__

    # Restore saved weights for each model. The checkpoint filenames embed
    # the save timestamp, epoch number, and dice score at save time.
    model1_checkpoint = torch.load(
        train().checkpointsPath + '/' + modelName1 + '/' +
        '2019-08-30 13:21:52.559302_epoch-5_dice-0.4743926368317377.pth')
    model2_checkpoint = torch.load(
        train().checkpointsPath + '/' + modelName2 + '/' +
        '2019-08-22 08:37:06.839794_epoch-1_dice-0.4479589270841744.pth')
    model3_checkpoint = torch.load(
        train().checkpointsPath + '/' + modelName3 + '/' +
        '2019-09-03 03:21:05.647040_epoch-253_dice-0.46157537277322264.pth')

    model1.load_state_dict(model1_checkpoint['model_state_dict'])
    model2.load_state_dict(model2_checkpoint['model_state_dict'])
    model3.load_state_dict(model3_checkpoint['model_state_dict'])

    try:
        # Hand all three models to the training harness.
        train().main(model1, model2, model3, device)
    except KeyboardInterrupt:
        # On Ctrl-C exit immediately; os._exit is the hard fallback if
        # sys.exit is intercepted (e.g. by worker threads).
        print('Interrupted')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
# ----- Example 3 -----
                              transforms=transformations,
                              output_size=(720, 720),
                              predct=True)
# NOTE(review): fragment — the constructor call closed above (presumably the
# dataset bound to `cv_dataset`) begins outside this chunk.
# cv_loader = DataLoader(dataset=cv_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=True)
# Define the prediction function
# print(cv_dataset.datalen)
# cm = np.array(COLORMAP).astype('uint8')
# cm = np.array(CamVid_colours).astype('uint8')
n_class = 9
# Build a 9-class SegNet, move it to the GPU, and switch to inference mode.
net = SegNet(num_classes=9)
# net=SegNet(num_classes=12)
net.cuda()
net.eval()
# NOTE(review): `dir` shadows the builtin of the same name — worth renaming.
dir = './checkpoints/baiduSegNet5.pth'
# Restore trained weights; the checkpoint stores the state dict under 'net'.
state = t.load(dir)
net.load_state_dict(state['net'])
# Grab a single sample (index 1) for a smoke-test prediction below.
test_data, test_label = cv_dataset[1]
print(test_data.size())

# out=net(Variable(test_data.unsqueeze(0)).cuda())
# print(out.data.size())
# pred = out.max(1)[1].squeeze().cpu().data.numpy()
# print(pred.shape)

def predict(im, label):  # Predict a segmentation result
    """Run the module-level `net` on a single image.

    Args:
        im: an unbatched image tensor (C, H, W); batched and moved to GPU here.
        label: ground-truth mask — currently unused, kept so existing callers
            that pass it keep working.

    Returns:
        A numpy array of per-pixel argmax class indices (H, W).
    """
    im = Variable(im.unsqueeze(0)).cuda()
    out = net(im)
    pred = out.max(1)[1].squeeze().cpu().data.numpy()
    # pred = cm[pred]  # (colormap conversion left disabled, as in original)
    # Bug fix: the computed prediction was never returned — the function
    # silently returned None.
    return pred
def main():
    """Train SegNet on CityScapes, optionally resuming from a snapshot.

    Reads configuration from the module-level `train_args` / `val_args`
    dicts and the `ckpt_path` / `exp_name` / `num_classes` /
    `ignored_label` globals; runs `train()` and `validate()` per epoch.
    """
    net = SegNet(num_classes=num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 0
    else:
        # Resume: the snapshot filename encodes training state as
        # '_'-separated fields — index 1 = epoch, 3 = best val loss,
        # 6 = corresponding mean IoU.
        # Bug fix: was a Python 2 `print` statement (syntax error under
        # Python 3, which the rest of the file targets).
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1])
        train_record['best_val_loss'] = float(split_snapshot[3])
        train_record['corr_mean_iu'] = float(split_snapshot[6])
        train_record['corr_epoch'] = curr_epoch

    net.train()

    # ImageNet normalization statistics (mean, std).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Joint image+mask transforms: upscale by 1/0.875 then random crop and
    # flip for training; center crop for validation.
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    # Masks: to tensor, then remap the ignored label onto the last class
    # index (whose loss weight is zeroed below).
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    # Inverse of img_transform, used for visualizing validation results.
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train',
                           simul_transform=train_simul_transform,
                           transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['batch_size'],
                              num_workers=16,
                              shuffle=True)
    val_set = CityScapes('val',
                         simul_transform=val_simul_transform,
                         transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=val_args['batch_size'],
                            num_workers=16,
                            shuffle=False)

    # Ignore the remapped "ignored" class by zeroing its loss weight.
    weight = torch.ones(num_classes)
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    def _select_params(is_bias, is_decoder):
        # Pick parameters by bias-ness and by whether the name marks a
        # decoder ('dec') layer; replaces four near-identical inline
        # comprehensions.
        return [
            param for name, param in net.named_parameters()
            if (name[-4:] == 'bias') == is_bias
            and ('dec' in name) == is_decoder
        ]

    # Four parameter groups: decoder params train at 'new_lr', pretrained
    # (non-decoder) params at 'pretrained_lr'; biases get a doubled LR and
    # no weight decay.
    optimizer = optim.SGD(
        [{
            'params': _select_params(True, True),
            'lr': 2 * train_args['new_lr']
        }, {
            'params': _select_params(False, True),
            'lr': train_args['new_lr'],
            'weight_decay': train_args['weight_decay']
        }, {
            'params': _select_params(True, False),
            'lr': 2 * train_args['pretrained_lr']
        }, {
            'params': _select_params(False, False),
            'lr': train_args['pretrained_lr'],
            'weight_decay': train_args['weight_decay']
        }],
        momentum=0.9,
        nesterov=True)

    if len(train_args['snapshot']) > 0:
        # Restore optimizer state, then re-apply the configured LRs (the
        # saved state may carry stale, decayed values).
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['new_lr']
        optimizer.param_groups[1]['lr'] = train_args['new_lr']
        optimizer.param_groups[2]['lr'] = 2 * train_args['pretrained_lr']
        optimizer.param_groups[3]['lr'] = train_args['pretrained_lr']

    # Ensure the checkpoint directories exist before training starts.
    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)
    if not os.path.exists(os.path.join(ckpt_path, exp_name)):
        os.mkdir(os.path.join(ckpt_path, exp_name))

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
# ----- Example 5 -----
            # NOTE(review): fragment — this chunk starts inside a loop/method;
            # `sample_test`, `i_test`, `model`, `device`, and `self` are bound
            # outside this view.
            images = sample_test[0].to(device)
            trueMasks = sample_test[1].to(device)
            predMasks = model(images)

            # Visualize the first predicted mask; the exp() suggests the model
            # outputs log-probabilities — TODO confirm against the model head.
            plt.figure()
            predTensor = (torch.exp(predMasks[0, 0, :, :]).detach().cpu())
            plt.imshow((predTensor / torch.max(predTensor)) * 255, cmap='gray')
            # Normalize to [0, 255], threshold at 127 into a boolean mask, and
            # save it as a .tif named after the sample index.
            pilTrans = transforms.ToPILImage()
            pilImg = pilTrans((predTensor / torch.max(predTensor)) * 255)
            pilArray = np.array(pilImg)
            pilArray = (pilArray > 127)
            im = Image.fromarray(pilArray)
            im.save(self.predMaskPath + '/' + str(i_test) + '.tif')

            print((predTensor / torch.max(predTensor)) * 255)

            # Mean dice coefficient over the batch, for logging.
            mBatchDice = torch.mean(Loss(trueMasks, predMasks).dice_coeff())
            print(mBatchDice.item())

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SegNet(1, 1).to(device)
    # model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False, num_classes=1)
    modelName = model.__class__.__name__
    checkpoint = torch.load(test().checkpointsPath + '/' + modelName + '/' +
                            test().modelWeight)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    test().main(model, device)