Beispiel #1
0
def CreateModel(args):
    """Build the segmentation network and, in training mode, its optimizer.

    Args:
        args: parsed command-line namespace. Reads ``model`` ('DeepLab' or
            'VGG'), ``set``, ``init_weights`` and ``restore_from``; 'DeepLab'
            also reads ``num_classes``. When ``args.set == 'train'`` it reads
            ``learning_rate`` as well, plus ``momentum`` and ``weight_decay``
            for 'DeepLab'.

    Returns:
        ``(model, optimizer)`` when ``args.set == 'train'`` (the optimizer's
        gradients already zeroed), otherwise just ``model``.

    Raises:
        ValueError: if ``args.model`` is neither 'DeepLab' nor 'VGG'.
    """
    if args.model == 'DeepLab':
        model = Deeplab(num_classes=args.num_classes,
                        init_weights=args.init_weights,
                        restore_from=args.restore_from,
                        phase=args.set)
        if args.set == 'train':
            optimizer = optim.SGD(model.optim_parameters(args),
                                  lr=args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
            optimizer.zero_grad()
            return model, optimizer
        return model

    if args.model == 'VGG':
        # NOTE(review): num_classes is fixed at 19 and ignores
        # args.num_classes -- confirm this is intended for VGG16_FCN8s.
        model = VGG16_FCN8s(num_classes=19,
                            init_weights=args.init_weights,
                            restore_from=args.restore_from)
        if args.set == 'train':
            # The get_parameters(bias=True) group runs at twice the base
            # learning rate; the other group uses the base rate.
            optimizer = optim.Adam(
                [
                    {'params': model.get_parameters(bias=False)},
                    {'params': model.get_parameters(bias=True),
                     'lr': args.learning_rate * 2},
                ],
                lr=args.learning_rate,
                betas=(0.9, 0.99))
            optimizer.zero_grad()
            return model, optimizer
        return model

    # Fail fast instead of implicitly returning None, which would only
    # surface later (e.g. when the caller unpacks ``model, optimizer``).
    raise ValueError(
        "Unknown model {!r}; expected 'DeepLab' or 'VGG'".format(args.model))
Beispiel #2
0
def CreateSSLModel(args):
    """Build the network selected by ``args.model`` (no optimizer).

    Args:
        args: parsed command-line namespace. Reads ``model`` ('DeepLab' or
            'VGG'), ``init_weights`` and ``restore_from``; 'DeepLab' also
            reads ``num_classes`` and ``set`` (passed as ``phase``).

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.model`` is neither 'DeepLab' nor 'VGG'.
    """
    if args.model == 'DeepLab':
        model = Deeplab(num_classes=args.num_classes,
                        init_weights=args.init_weights,
                        restore_from=args.restore_from,
                        phase=args.set)
    elif args.model == 'VGG':
        # NOTE(review): num_classes is fixed at 19 and ignores
        # args.num_classes -- confirm this is intended for VGG16_FCN8s.
        model = VGG16_FCN8s(num_classes=19,
                            init_weights=args.init_weights,
                            restore_from=args.restore_from)
    else:
        # Name the accepted spellings and echo the bad value so a mistyped
        # command-line flag is easy to spot.
        raise ValueError(
            "The model must be either deeplab-101 ('DeepLab') or vgg16-fcn "
            "('VGG'), got {!r}".format(args.model))
    return model
Beispiel #3
0
def CreateModel(args):
    """Build the network selected by ``args.model`` and, when training, its optimizer.

    Args:
        args: parsed command-line namespace. Always reads ``model`` and
            ``set``; 'DeepLab' and 'VGG' read ``init_weights`` and
            ``restore_from``, 'CLS' reads ``restore_from``, and 'DeepLab'
            also reads ``num_classes``. In training mode ``learning_rate`` is
            read, plus ``momentum`` and ``weight_decay`` for the SGD-based
            models ('DeepLab' and 'CLS').

    Returns:
        ``(model, optimizer)`` when ``args.set`` is 'train' or 'trainval'
        (the optimizer's gradients already zeroed), otherwise just ``model``.

    Raises:
        ValueError: if ``args.model`` is not 'DeepLab', 'VGG' or 'CLS'.
    """
    is_train = args.set in ('train', 'trainval')

    if args.model == 'DeepLab':
        model = Deeplab(num_classes=args.num_classes,
                        init_weights=args.init_weights,
                        restore_from=args.restore_from,
                        phase='train' if is_train else 'test')
    elif args.model == 'VGG':
        # NOTE(review): num_classes is fixed at 19 and ignores
        # args.num_classes -- confirm this is intended for VGG16_FCN8s.
        model = VGG16_FCN8s(num_classes=19,
                            init_weights=args.init_weights,
                            restore_from=args.restore_from)
    elif args.model == 'CLS':
        model = CLSNet(restore_from=args.restore_from)
    else:
        # Fail fast instead of implicitly returning None.
        raise ValueError(
            "Unknown model {!r}; expected 'DeepLab', 'VGG' or 'CLS'".format(
                args.model))

    if not is_train:
        return model

    if args.model == 'VGG':
        # The get_parameters(bias=True) group runs at twice the base
        # learning rate; the other group uses the base rate.
        optimizer = optim.Adam([
            {'params': model.get_parameters(bias=False)},
            {'params': model.get_parameters(bias=True),
             'lr': args.learning_rate * 2},
        ],
                               lr=args.learning_rate,
                               betas=(0.9, 0.99))
    else:
        # DeepLab and CLS share the same SGD configuration.
        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    optimizer.zero_grad()
    return model, optimizer
Beispiel #4
0
def CreateSSLModel(args):
    """Build the network selected by ``args.model`` (no optimizer).

    Args:
        args: parsed command-line namespace. Reads ``model`` ('DeepLab' or
            'VGG'), ``init_weights`` and ``restore_from``; 'DeepLab' also
            reads ``num_classes`` and ``set`` (passed as ``phase``).

    Raises:
        ValueError: args.model should be either 'DeepLab' (deeplab-101) or
            'VGG' (vgg16-fcn).

    Returns:
        model: based on command line arguments
    """
    if args.model == 'DeepLab':
        model = Deeplab(num_classes=args.num_classes,
                        init_weights=args.init_weights,
                        restore_from=args.restore_from,
                        phase=args.set)
    elif args.model == 'VGG':
        # NOTE(review): num_classes is fixed at 19 and ignores
        # args.num_classes -- confirm this is intended for VGG16_FCN8s.
        model = VGG16_FCN8s(num_classes=19,
                            init_weights=args.init_weights,
                            restore_from=args.restore_from)
    else:
        # Name the accepted spellings and echo the bad value so a mistyped
        # command-line flag is easy to spot.
        raise ValueError(
            "The model must be either deeplab-101 ('DeepLab') or vgg16-fcn "
            "('VGG'), got {!r}".format(args.model))
    return model
Beispiel #5
0
def CreateModel(args):
    """Build the network selected by ``args.model`` (moved to the GPU via
    ``.cuda()``) together with an optimizer.

    'enet' uses Adam with ``learning_rate`` and ``weight_decay``; 'frrnet'
    uses Adam with ``learning_rate`` only; 'deeplab' gets no optimizer
    (``None``).

    NOTE(review): this definition is incomplete as shown -- the final
    ``else:`` has no body and there is no ``return`` statement. Presumably
    it returns ``(model, optimizer)``; confirm against the original project.
    """

    if args.model == 'enet':
        model = ENet(num_classes=args.num_classes).cuda()
        optimizer = optim.Adam(
            model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    elif args.model == 'deeplab':
        model = Deeplab(num_classes=args.num_classes).cuda()
        # No optimizer is created for this model.
        optimizer = None

    elif args.model == 'frrnet':
        model = FRRNet(out_channels=args.num_classes).cuda()
        optimizer = optim.Adam(
            model.parameters(), lr=args.learning_rate)

    # NOTE(review): the triple-quoted string below is an expression
    # statement, not a comment. It terminates the if/elif chain, so the
    # following ``else:`` is a SyntaxError. Turn it into ``#`` comments (or
    # delete it) so the ``else`` attaches to the chain again.
    """
    elif args.model == 'fcn8s':
        model = FCN8s(pretrained_net=None, n_class=args.num_classes).cuda()
        optimizer = optim.SGD(
            model.parameters(), lr=args.learning_rate)
    """
    
    else:
    # NOTE(review): the body of this branch is missing from the snippet.
import torch
from torchsummaryX import summary

from model.utils.block import Block
from model.encoder import _Entry, _Middle, _Exit, Encoder
from model.deeplab import Deeplab
from model.decoder import Decoder

if __name__ == "__main__":
    # Smoke test: build one network, run a single forward pass and print the
    # output shape. To check a sub-module instead, swap in one of the
    # commented-out constructors (its expected input shape may differ, so
    # adjust ``x`` accordingly).
    # model = Block(728, 728)
    # model = _Entry()
    # model = _Middle()
    # model = _Exit()
    # model = Encoder()
    # model = Decoder(128,20)
    # Prefer the GPU, but still allow the script to run on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Deeplab()
    model.to(device)
    model.eval()
    # torch.Tensor(*sizes) returns uninitialised memory (which may hold
    # NaN/Inf), so use well-defined random input instead.
    x = torch.randn(1, 3, 512, 512, device=device)
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        output = model(x)
    print(output.shape)
    # print(model)