コード例 #1
0
def get_model(name, n_classes):
    """Return an FCN-8s network whose backbone is seeded from torchvision VGG16.

    Args:
        name: accepted for interface compatibility but currently ignored --
            an ``fcn8s`` model is always built.
        n_classes: number of output classes, forwarded to ``fcn8s``.

    Returns:
        The ``fcn8s`` instance after ``init_vgg16_params`` has copied weights
        from ``models.vgg16(pretrained=True)``.
    """
    # The segmentation net is constructed before the VGG16 backbone, as in the
    # original, so random-number consumption (and thus init) is unchanged.
    net = fcn8s(n_classes)
    net.init_vgg16_params(models.vgg16(pretrained=True))
    return net
コード例 #2
0
# Evaluate a saved FCN-8s checkpoint on the test split.
# NOTE(review): relies on ``conf``, ``test_dataset``, ``models``-free setup and
# a ``test`` function defined earlier in this script (not visible here).

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=conf['batch_size'])

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    # set_device raises on hosts without CUDA, so only select a GPU if one exists.
    torch.cuda.set_device(conf['gpu'])
print(device)

dataset_sizes = {'test': len(test_dataset)}

from ptsemseg.models.fcn import fcn8s

model = fcn8s(n_classes=1)

# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
state_dict = torch.load("./models/FCN8_ep4.net", map_location=device)
# Accept checkpoints written from a DataParallel-wrapped model as well: strip
# the "module." prefix the wrapper adds to every parameter name.
_prefix = 'module.'
state_dict = {
    (key[len(_prefix):] if key.startswith(_prefix) else key): value
    for key, value in state_dict.items()
}
# Load into the bare model BEFORE wrapping it in DataParallel, otherwise the
# expected keys would carry the "module." prefix and not match.
model.load_state_dict(state_dict)

model = model.to(device)
if conf['num_gpus'] > 1:
    # NOTE(review): DataParallel expects the module on device_ids[0]; this
    # assumes conf['gpu'] == 0 when more than one GPU is used -- confirm.
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

from train import accuracy, iou

# NOTE(review): ``epoch`` is not defined in the visible code -- confirm it is
# set earlier in the script.
test(model, device, test_loader, epoch, dataset_sizes['test'], conf)
コード例 #3
0
ファイル: train_pascal.py プロジェクト: weigq/pytorch-semseg
def _build_model(name):
    """Instantiate the segmentation network called *name*.

    FCN variants are additionally initialised from
    ``models.vgg16(pretrained=True)`` via ``init_vgg16_params``.

    Raises:
        ValueError: if *name* is not a supported architecture.
    """
    if name == 'unet':
        return unet(feature_scale=feature_scale,
                    n_classes=n_classes,
                    is_batchnorm=True,
                    in_channels=3,
                    is_deconv=True)
    if name == 'segnet':
        return segnet(n_classes=n_classes, in_channels=3, is_unpooling=True)

    # Each constructor name is only looked up in its own branch, as before.
    if name == 'fcn32':
        net = fcn32s(n_classes=n_classes)
    elif name == 'fcn16':
        net = fcn16s(n_classes=n_classes)
    elif name == 'fcn8':
        net = fcn8s(n_classes=n_classes)
    else:
        raise ValueError("Unknown model name: %r" % (name,))
    net.init_vgg16_params(models.vgg16(pretrained=True))
    return net


def _loss_scalar(loss):
    """Return the value of a one-element loss as a Python number."""
    # ``.item()`` exists from PyTorch 0.4 on; there, ``loss.data[0]`` raises on
    # the 0-dim tensors losses now are. Older releases only offer ``.data[0]``.
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


def train(model):
    """Train a segmentation network on PASCAL VOC, visualising with visdom.

    Args:
        model: architecture name -- one of 'unet', 'segnet', 'fcn32',
            'fcn16' or 'fcn8'.

    Hyper-parameters (``n_classes``, ``feature_scale``, ``data_path``,
    ``img_rows``, ``batch_size``, ``l_rate``, ``n_epoch``) are read from
    module-level globals. Optimisation is SGD (momentum 0.99, weight decay
    5e-4) on ``cross_entropy2d``. After every epoch the input, ground truth
    and prediction for the first dataset sample are sent to visdom. The
    whole module is finally saved with ``torch.save``.

    Raises:
        ValueError: if *model* is not a supported architecture name.
    """
    net = _build_model(model)
    use_cuda = torch.cuda.is_available()

    pascal = pascalVOCLoader(data_path, is_transform=True, img_size=img_rows)
    trainloader = data.DataLoader(pascal, batch_size=batch_size, num_workers=4)

    if use_cuda:
        net.cuda(0)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=l_rate,
                                momentum=0.99,
                                weight_decay=5e-4)

    # Fixed sample used for per-epoch visual inspection; only moved to the GPU
    # when one is available so CPU-only training does not crash here.
    test_image, test_segmap = pascal[0]
    test_image = test_image.unsqueeze(0)
    if use_cuda:
        test_image = test_image.cuda(0)
    test_image = Variable(test_image)
    vis = visdom.Visdom()

    for epoch in range(n_epoch):
        for i, (images, labels) in enumerate(trainloader):
            if use_cuda:
                images, labels = images.cuda(0), labels.cuda(0)
            images, labels = Variable(images), Variable(labels)

            optimizer.zero_grad()
            outputs = net(images)

            loss = cross_entropy2d(outputs, labels)

            loss.backward()
            optimizer.step()

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, n_epoch, _loss_scalar(loss)))

        # Evaluate the preview in eval mode: in train mode dropout would
        # perturb the prediction and BatchNorm would update its running
        # statistics from this single image.
        net.eval()
        test_output = net(test_image)
        predicted = pascal.decode_segmap(
            test_output[0].cpu().data.numpy().argmax(0))
        target = pascal.decode_segmap(test_segmap.numpy())

        vis.image(test_image[0].cpu().data.numpy(),
                  opts=dict(title='Input' + str(epoch)))
        vis.image(np.transpose(target, [2, 0, 1]),
                  opts=dict(title='GT' + str(epoch)))
        vis.image(np.transpose(predicted, [2, 0, 1]),
                  opts=dict(title='Predicted' + str(epoch)))
        net.train()

    # NOTE(review): the checkpoint name always starts with "unet_voc_" even for
    # non-UNet architectures; kept unchanged so existing loaders of this path
    # keep working.
    torch.save(net, "unet_voc_" + str(feature_scale) + ".pkl")