Exemplo n.º 1
0
p['momentum'] = 0.9  # Momentum value handed to the SGD optimizer below
p['epoch_size'] = 2  # Number of epochs between learning-rate changes (per the original note; the schedule itself is not shown here)

# Model selector. The network-definition branch below only handles 'deeplab';
# any other value (including the 'unet' option mentioned by the original author)
# ends in NotImplementedError.
p['Model'] = 'deeplab'
backbone = 'xception'  # For deeplab only: 'xception' or 'resnet' picks the feature-extractor module
num_of_classes = 2  # Passed as n_classes to DeepLabv3_plus
imsize = 512  # 256 or 512
# Passed as `os` to DeepLabv3_plus. Author's note: 8 or 16, 8 is better; it
# controls the output stride of the deeplab model (finer feature resolution).
output_stride = 8
numInputChannels = 3  # Passed as nInputChannels to DeepLabv3_plus

# Network definition
# Builds `net`, `modelName`, `optimizer` and `criterion` for the configured model.
# Only p['Model'] == 'deeplab' with an 'xception' or 'resnet' backbone is handled;
# every other combination raises NotImplementedError naming the rejected value.
if p['Model'] == 'deeplab':
    # Note: `os=` below is the constructor's keyword argument (output stride),
    # not the standard-library os module.
    if backbone == 'xception':
        net = deeplab_xception.DeepLabv3_plus(nInputChannels=numInputChannels, n_classes=num_of_classes, os=output_stride, pretrained=True)
    elif backbone == 'resnet':
        net = deeplab_resnet.DeepLabv3_plus(nInputChannels=numInputChannels, n_classes=num_of_classes, os=output_stride, pretrained=True)
    else:
        # Report the offending value so a typo in `backbone` is obvious.
        raise NotImplementedError(
            "Unsupported backbone {!r}; expected 'xception' or 'resnet'".format(backbone))
    modelName = 'deeplabv3plus-' + backbone

    # Use the following optimizer
    optimizer = optim.SGD(net.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
    # Store a printable description of the optimizer next to the other settings in p.
    p['optimizer'] = str(optimizer)

    # Use the following loss function
    criterion = utils.cross_entropy2d
else:
    # Distinguish an unsupported model from an unsupported backbone.
    raise NotImplementedError(
        "Unsupported model {!r}; only 'deeplab' is implemented here".format(p['Model']))

# Prefer the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# TODO: decide whether `criterion = criterion.to(device)` is needed. criterion is
# bound to utils.cross_entropy2d above; if that is a plain function rather than a
# module object it presumably has no .to() to call -- confirm before enabling.
Exemplo n.º 2
0
    run_id = int(runs[-1].split('_')[-1]) if runs else 0
else:
    runs = sorted(glob.glob(os.path.join(save_dir_root, 'run', 'run_*')))
    run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0

# Directory for this run: <save_dir_root>/run/run_<run_id>
save_dir = os.path.join(save_dir_root, 'run', 'run_{}'.format(run_id))

# Network definition
# Only the 'xception' and 'resnet' backbones are wired up; reject anything else
# before touching either backbone module.
if backbone not in ('xception', 'resnet'):
    raise NotImplementedError

# Both backbone modules are called with the same keyword arguments, so pick the
# module and construct once. `os=16` is the constructor's keyword argument, not
# the standard-library os module.
net = (deeplab_xception if backbone == 'xception' else deeplab_resnet).DeepLabv3_plus(
    nInputChannels=3,
    n_classes=19,
    os=16,
    pretrained=True)

# modelName is the prefix of the checkpoint file names used further down.
modelName = 'deeplabv3plus-{}-cityscapes'.format(backbone)
# Loss function used for training.
criterion = utils.cross_entropy2d

if resume_epoch == 0:
    print("Training deeplabv3+ from scratch...")
else:
    print("Initializing weights from: {}...".format(
        os.path.join(save_dir, 'models',
                     modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')))
    net.load_state_dict(
        torch.load(os.path.join(
            save_dir, 'models',