Example #1
0
def create_models():
    """Build the reference UNet plus two generalized variants, then
    synchronize both variants' parameters with the reference model.

    Returns:
        Tuple of (unet_o, unet_1, unet_2).
    """
    # Reference implementation.
    unet_o = unet(
        in_channels=3,
        n_classes=2,
        feature_scale=4,
        is_deconv=True,
        is_batchnorm=True,
    )
    # Generalized UNet with an explicit feature level.
    unet_1 = GeneralUNet(
        in_channels=3,
        out_channels=2,
        feature_scale=4,
        is_deconv=True,
        is_batchnorm=True,
        feature_level=4,
    )
    # Second-generation generalized UNet (note: `is_norm`, not `is_batchnorm`).
    unet_2 = GeneralUNet_v2(
        in_channels=3,
        out_channels=2,
        feature_scale=4,
        is_deconv=True,
        is_norm=True,
        feature_level=4,
    )

    # Copy the reference weights into each variant.
    for variant in (unet_1, unet_2):
        unify_parameters(unet_o, variant)

    return unet_o, unet_1, unet_2
Example #2
0
def test_unet_only_hidden():
    """Smoke-test: forward a random batch through UNetOnlyHidden built
    from default CLI arguments."""
    args = train_parser().parse_args()
    model = UNetOnlyHidden(args, n_classes=2)

    batch = torch.rand(size=[3, 3, 32, 32])
    _ = model(batch)
Example #3
0
def testing_input():
    """Feed one random batch through all three UNet variants and check
    that the v2 model reproduces the original numerically.

    After `unify_parameters`, `unets[2]` must match `unets[0]`; the
    differences against `unets[1]` are printed for inspection only.
    """
    inp = torch.rand(size=[3, 3, 32, 32])
    unets = create_models()
    outs = [net(inp) for net in unets]

    # Concatenation along the batch dim must succeed (shapes agree).
    final_out = torch.cat(outs, dim=0)
    assert final_out.shape[0] == inp.shape[0] * len(outs)

    # BUG FIX: a *signed* mean cancels positive and negative errors and
    # could pass even for wildly different outputs; compare the mean
    # absolute difference instead.
    assert (outs[2] - outs[0]).abs().mean() < 1e-6
    print((outs[1] - outs[0]).mean())
    print((outs[1] - outs[2]).mean())
Example #4
0
def test_runet_with_different_level():
    """Instantiate GeneralRecurrentUnet at several recurrent/unet level
    splits (summing to 5) and validate the recurrent output of each."""
    args = train_parser().parse_args()
    sample = torch.rand(size=[3, 3, 64, 64])

    total_levels = 5
    for recurrent_level in (3, 4):
        args.recurrent_level = recurrent_level
        args.unet_level = total_levels - recurrent_level
        model = GeneralRecurrentUnet(args, 2)
        print(model)
        _test_recurrent_output(model(sample))
Example #5
0
    def __init__(self,
                 x_size,
                 y_size,
                 hidden_size,
                 dropout_rate=0,
                 normalize=True,
                 c_max_len=64):
        """Build the OuterNet submodules.

        Args:
            x_size: input feature dimension.
            y_size: not used in this constructor; kept for interface
                compatibility — presumably consumed elsewhere (TODO confirm).
            hidden_size: hidden projection dimension.
            dropout_rate: dropout probability (default 0).
            normalize: stored flag; consumed outside this constructor.
            c_max_len: maximum context length (default 64).
        """
        super(OuterNet, self).__init__()

        # Configuration flags/sizes.
        self.normalize = normalize
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.c_max_len = c_max_len

        # Start/end boundary projections (bias disabled).
        self.w_s = nn.Linear(x_size, hidden_size, False)
        self.w_e = nn.Linear(x_size, hidden_size, False)

        self.dropout = nn.Dropout(p=self.dropout_rate)

        # Segmentation backbone. Commented-out alternatives in the
        # original (MyUNet, PSP, MyPSPNet, plain Conv2d) suggest other
        # backbones were tried and set aside.
        self.seg_net = unet(feature_scale=8, in_channels=1, n_classes=1)

        # Two-stage projection down to the context length.
        self.w1 = nn.Linear(x_size, hidden_size, False)
        self.w2 = nn.Linear(hidden_size, c_max_len, False)

        # Bottlenecked transform (x_size -> 128 -> x_size) with Tanh output.
        self.w_transform = nn.Sequential(
            nn.Linear(x_size, 128, False),
            nn.ReLU(),
            nn.Linear(128, x_size, False),
            nn.Tanh(),
        )

        self.w_t = nn.Linear(x_size, x_size, False)

        self.w_t1 = nn.Linear(x_size, c_max_len, False)
        self.w_t2 = nn.Linear(x_size, c_max_len, False)

        # Learned 2-D embeddings indexed by row/column position.
        self.row_embeds = nn.Embedding(c_max_len, 2)
        self.col_embeds = nn.Embedding(c_max_len, 2)
Example #6
0
def train(model):
    """Train a segmentation model on Pascal VOC and visualize progress.

    Args:
        model: one of 'unet', 'segnet', 'fcn32', 'fcn16', 'fcn8'.
            The name is replaced by the instantiated model in place.

    Raises:
        ValueError: if `model` is not a recognized model name.

    Side effects: pulls hyper-parameters (feature_scale, n_classes,
    l_rate, n_epoch, batch_size, img_rows, data_path) from module
    globals, pushes images to a running visdom server, and saves the
    trained model to "unet_voc_<feature_scale>.pkl".
    """
    # BUG FIX: the original used a chain of independent `if`s; after the
    # first match `model` is an nn.Module, so the later `model == '...'`
    # comparisons are always False only by accident, and an unknown name
    # fell through to `model.parameters()` with a confusing
    # AttributeError. Use elif + an explicit error.
    if model == 'unet':
        model = unet(feature_scale=feature_scale,
                     n_classes=n_classes,
                     is_batchnorm=True,
                     in_channels=3,
                     is_deconv=True)
    elif model == 'segnet':
        model = segnet(n_classes=n_classes, in_channels=3, is_unpooling=True)
    elif model == 'fcn32':
        model = fcn32s(n_classes=n_classes)
        vgg16 = models.vgg16(pretrained=True)
        model.init_vgg16_params(vgg16)
    elif model == 'fcn16':
        model = fcn16s(n_classes=n_classes)
        vgg16 = models.vgg16(pretrained=True)
        model.init_vgg16_params(vgg16)
    elif model == 'fcn8':
        model = fcn8s(n_classes=n_classes)
        vgg16 = models.vgg16(pretrained=True)
        model.init_vgg16_params(vgg16)
    else:
        raise ValueError("unknown model name: %r" % (model,))

    pascal = pascalVOCLoader(data_path, is_transform=True, img_size=img_rows)
    trainloader = data.DataLoader(pascal, batch_size=batch_size, num_workers=4)

    # Hoist the CUDA check; it is loop-invariant.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda(0)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=l_rate,
                                momentum=0.99,
                                weight_decay=5e-4)

    # Fixed sample used to visualize predictions after each epoch.
    test_image, test_segmap = pascal[0]
    test_image = test_image.unsqueeze(0)
    # BUG FIX: the original called .cuda(0) unconditionally here even
    # though the training loop guards on availability — this crashed on
    # CPU-only machines.
    if use_cuda:
        test_image = test_image.cuda(0)
    test_image = Variable(test_image)
    vis = visdom.Visdom()

    for epoch in range(n_epoch):
        for i, (images, labels) in enumerate(trainloader):
            if use_cuda:
                images = Variable(images.cuda(0))
                labels = Variable(labels.cuda(0))
            else:
                images = Variable(images)
                labels = Variable(labels)

            optimizer.zero_grad()
            outputs = model(images)

            loss = cross_entropy2d(outputs, labels)

            loss.backward()
            optimizer.step()

            if (i + 1) % 20 == 0:
                # BUG FIX: `loss.data[0]` raises on 0-dim tensors in
                # modern PyTorch; `loss.item()` is the supported way to
                # read a scalar loss.
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, n_epoch, loss.item()))

        # Per-epoch qualitative check: predict on the fixed sample and
        # push input / ground truth / prediction to visdom.
        test_output = model(test_image)
        predicted = pascal.decode_segmap(
            test_output[0].cpu().data.numpy().argmax(0))
        target = pascal.decode_segmap(test_segmap.numpy())

        vis.image(test_image[0].cpu().data.numpy(),
                  opts=dict(title='Input' + str(epoch)))
        vis.image(np.transpose(target, [2, 0, 1]),
                  opts=dict(title='GT' + str(epoch)))
        vis.image(np.transpose(predicted, [2, 0, 1]),
                  opts=dict(title='Predicted' + str(epoch)))

    torch.save(model, "unet_voc_" + str(feature_scale) + ".pkl")