Example #1
def setup_net(snapshot):
    """Quickly create a network for the given snapshot.
    
    Arguments:
        snapshot {string} -- Path to the input snapshot, e.g. kitti_best.pth
    
    Returns:
        [net, img_transform, args] -- PyTorch model, the image transform function, and the Args object used to build the network.
    """
    cudnn.benchmark = False
    torch.cuda.empty_cache()

    args = Args('./save', 'network.deepv3.DeepWV3Plus', snapshot)

    assert_and_infer_cfg(args, train_mode=False)
    # get net
    net = network.get_net(args, criterion=None)
    net = torch.nn.DataParallel(net).cuda()
    print('Net built.')

    net, _ = restore_snapshot(net,
                              optimizer=None,
                              snapshot=snapshot,
                              restore_optimizer_bool=False)
    net.eval()
    print('Net restored.')

    # get data
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    img_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(*mean_std)])

    return net, img_transform, args
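A minimal sketch of how setup_net might be used downstream; the snapshot filename and input image below are placeholders, not files referenced elsewhere in these examples:

# Illustrative usage only: the snapshot and image paths are assumptions.
from PIL import Image
import torch

net, img_transform, args = setup_net('kitti_best.pth')

img = Image.open('example.png').convert('RGB')
img_tensor = img_transform(img).unsqueeze(0).cuda()  # add a batch dimension

with torch.no_grad():
    pred = net(img_tensor)                  # raw per-class logits, shape (1, C, H, W)
    seg_map = pred.argmax(dim=1).squeeze()  # per-pixel class indices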
Example #2
def citymode():
    model = deepv3.DeepWV3Plus(21)

    model_path = rep_path + '/pretrained_models/cityscapes_best.pth'
    model, _ = restore_snapshot(model,
                                optimizer=None,
                                snapshot=model_path,
                                restore_optimizer_bool=False)
    model.cpu()

    return model
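A hedged sketch of exercising the CPU model returned by citymode(); the dummy input size is an assumption chosen only to illustrate the call:

# Illustrative usage only: feeds a random tensor through the CPU model.
import torch

model = citymode()
model.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 512, 1024)  # NCHW; any reasonable spatial size works
    logits = model(dummy)                 # scores for the 21 classes passed to DeepWV3Plus
print(logits.shape)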
Example #3
def get_net():
    """
    Get Network for evaluation
    """
    logging.info('Load model file: %s', args.snapshot)
    net = network.get_net_ori(args, criterion=None)

    net = torch.nn.DataParallel(net).cuda()
    net, _ = restore_snapshot(net, optimizer=None,
                              snapshot=args.snapshot, snapshot2=args.snapshot2, restore_optimizer_bool=False)
    net.eval()
    return net
Example #4
def get_net(optimizer=None, criterion=None):
    """
    Get network for training
    :return:
    """
    net = network.get_net(args, criterion=criterion)
    net, _ = restore_snapshot(net,
                              optimizer=optimizer,
                              snapshot=args.snapshot,
                              restore_optimizer_bool=False)
    #net.train()
    return net
Example #5
def get_net():
    """
    Get Network for evaluation
    """
    logging.info('Load model file: %s', args.snapshot)
    net = network.get_net(args, criterion=None)
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = network.warp_network_in_dataparallel(net, args.local_rank)
    net, _, _, _, _ = restore_snapshot(net, optimizer=None, scheduler=None,
                              snapshot=args.snapshot, restore_optimizer_bool=False)

    net.eval()
    return net
Example #6
def get_net():
    """
    Get Network for evaluation
    """
    logging.info('Load model file: %s', args.snapshot)
    net = network.get_net(args, criterion=None)
    if args.inference_mode == 'pooling':
        net = MyDataParallel(net, gather=False).cuda()
    else:
        net = torch.nn.DataParallel(net).cuda()
    net, _ = restore_snapshot(net, optimizer=None,
                              snapshot=args.snapshot, restore_optimizer_bool=False)
    net.eval()
    return net
Example #7
def get_segmentation(self):
    # Get Segmentation Net
    assert_and_infer_cfg(self.opt, train_mode=False)
    self.opt.dataset_cls = cityscapes
    net = network.get_net(self.opt, criterion=None)
    net = torch.nn.DataParallel(net).cuda()
    print('Segmentation Net Built.')
    snapshot = os.path.join(os.getcwd(), os.path.dirname(__file__),
                            self.opt.snapshot)
    self.seg_net, _ = restore_snapshot(net,
                                       optimizer=None,
                                       snapshot=snapshot,
                                       restore_optimizer_bool=False)
    self.seg_net.eval()
    print('Segmentation Net Restored.')
Example #8
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--demo-image', type=str, default='', help='path to demo image', required=True)
parser.add_argument('--snapshot', type=str, default='./pretrained_models/cityscapes_best_wideresnet38.pth', help='pre-trained checkpoint', required=True)
parser.add_argument('--arch', type=str, default='network.deepv3.DeepWV3Plus', help='network architecture used for inference')
parser.add_argument('--save-dir', type=str, default='./save', help='path to save your results')
args = parser.parse_args()
assert_and_infer_cfg(args, train_mode=False)
cudnn.benchmark = False
torch.cuda.empty_cache()

# get net
args.dataset_cls = cityscapes
net = network.get_net(args, criterion=None)
net = torch.nn.DataParallel(net).cuda()
print('Net built.')
net, _ = restore_snapshot(net, optimizer=None, snapshot=args.snapshot, restore_optimizer_bool=False)
net.eval()
print('Net restored.')

# get data
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(*mean_std)])
img = Image.open(args.demo_image).convert('RGB')
img_tensor = img_transform(img)

# predict
with torch.no_grad():
    img = img_tensor.unsqueeze(0).cuda()
    pred = net(img)
    print('Inference done.')
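A possible continuation of the demo above, turning the raw logits into a saved class-index image; only numpy and PIL are used here, and the original demo's post-processing may differ:

# Illustrative only: argmax over classes and save the result as a PNG.
import os
import numpy as np
from PIL import Image

pred = pred.cpu().numpy().squeeze()                   # (C, H, W) logits
class_map = np.argmax(pred, axis=0).astype(np.uint8)  # (H, W) class indices

os.makedirs(args.save_dir, exist_ok=True)
Image.fromarray(class_map).save(os.path.join(args.save_dir, 'prediction.png'))
print('Prediction saved to', args.save_dir)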
Example #9
def convert_segmentation_model(model_name='segmentation.onnx'):

    assert_and_infer_cfg(opt, train_mode=False)
    cudnn.benchmark = False
    torch.cuda.empty_cache()

    # Get segmentation Net
    opt.dataset_cls = cityscapes
    net = network.get_net(opt, criterion=None)
    net = torch.nn.DataParallel(net).cuda()
    print('Segmentation Net built.')
    net, _ = restore_snapshot(net,
                              optimizer=None,
                              snapshot=opt.snapshot,
                              restore_optimizer_bool=False)
    net.eval()
    print('Segmentation Net Restored.')

    # Input to the model
    batch_size = 1

    x = torch.randn(batch_size, 3, 1024, 2048, requires_grad=True).cuda()
    torch_out = net(x)

    # Export the model
    torch.onnx.export(
        net.module,                # model being run
        x,                         # model input (or a tuple for multiple inputs)
        model_name,                # where to save the model (a file path or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=11,          # the ONNX opset version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=['input'],     # the model's input names
        output_names=['output'],   # the model's output names
        dynamic_axes={
            'input': {0: 'batch_size'},   # variable length axes
            'output': {0: 'batch_size'}
        })

    ort_session = onnxruntime.InferenceSession(model_name)

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(None, ort_inputs)

    # compare ONNX Runtime and PyTorch results
    np.testing.assert_allclose(to_numpy(torch_out),
                               ort_outs[0],
                               rtol=1e-03,
                               atol=1e-03)

    print("Exported model has been tested with ONNXRuntime, and the result looks good!")