Example #1
import os

# `Model` is assumed to come from the surrounding project
# (NVIDIA retinanet-examples).
def load_model(args, verbose=False):
    if args.command != 'train' and not os.path.isfile(args.model):
        raise RuntimeError('Model file {} does not exist!'.format(args.model))

    model = None
    state = {}
    _, ext = os.path.splitext(args.model)

    if args.command == 'train' and (not os.path.exists(args.model) or args.override):
        if verbose: print('Initializing model...')
        model = Model(args.backbone, args.classes)
        model.initialize(args.fine_tune)
        if verbose: print(model)

    elif ext in ('.pth', '.torch'):
        if verbose: print('Loading model from {}...'.format(os.path.basename(args.model)))
        model, state = Model.load(args.model)
        if verbose: print(model)

    elif args.command == 'infer' and ext in ['.engine', '.plan']:
        model = None
    
    else:
        raise RuntimeError('Invalid model format "{}"!'.format(ext))

    state['path'] = args.model
    return model, state
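For context, `load_model` is driven by an argparse-style namespace; a minimal sketch of a call site, assuming hypothetical flag names matching the attributes used above:

import argparse

# Hypothetical CLI wiring; the real project defines many more options.
parser = argparse.ArgumentParser()
parser.add_argument('command', choices=['train', 'infer', 'export'])
parser.add_argument('--model', default='model.pth')
parser.add_argument('--backbone', default='ResNet50FPN')
parser.add_argument('--classes', type=int, default=80)
parser.add_argument('--fine-tune', dest='fine_tune', default=None)
parser.add_argument('--override', action='store_true')
args = parser.parse_args(['train', '--model', 'model.pth'])

model, state = load_model(args, verbose=True)
print(state['path'])  # the path recorded by load_model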
Example #2
import os

# `cfg` is assumed to be a module-level, yacs-style config node and
# `Model` a project class; both come from the surrounding project.
def load_model(args, verbose=False):
    if not os.path.isfile(args.config_file):
        raise RuntimeError('Config file {} does not exist!'.format(
            args.config_file))
    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    model = None
    state = {}
    _, ext = os.path.splitext(cfg.MODEL.WEIGHT)

    if ext in ('.pth', '.torch'):
        if verbose:
            print('Loading model from {}...'.format(
                os.path.basename(cfg.MODEL.WEIGHT)))

        model = Model.load(cfg)
        if verbose:
            print(model)
    elif ext in ['.engine', '.plan']:
        model = None

    else:
        raise RuntimeError('Invalid model format "{}"!'.format(ext))

    state['path'] = cfg.MODEL.WEIGHT
    return model, state
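The `cfg` used here follows the yacs pattern (merge_from_file, freeze) popularized by maskrcnn-benchmark. A minimal sketch of how such a config node might be declared, assuming hypothetical defaults:

from yacs.config import CfgNode as CN

# Hypothetical defaults; real configs define many more keys.
cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.WEIGHT = ''  # path to a .pth/.torch checkpoint

# At startup the CLI overlays a YAML file and locks the config:
# cfg.merge_from_file('configs/retinanet.yaml')
# cfg.freeze()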
Example #3
import os

def load_model(args, verbose=False):
    if not os.path.isfile(args.config_file):
        raise RuntimeError('Config file {} does not exist!'.format(
            args.config_file))
    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    if verbose:
        print('Loading model from {}...'.format(
            os.path.basename(cfg.MODEL.WEIGHT)))
    model = Model.load(cfg)
    if verbose:
        print(model)

    state = cfg.MODEL.WEIGHT  # a bare path string here, not a dict as in the other variants
    return model, state
Example #4
import json
import os

def load_model(args, verbose=False):
    config = {}

    if args.config:
        with open(args.config, 'r') as config_file:
            config = json.load(config_file)

    if args.command != 'train' and not os.path.isfile(args.model):
        raise RuntimeError('Model file {} does not exist!'.format(args.model))

    model = None
    state = {}
    model_name, ext = os.path.splitext(args.model)

    if args.command == 'train' and (not os.path.exists(args.model)
                                    or args.override):
        if verbose: print('Initializing model...')
        model = Model(backbones=args.backbone,
                      classes=args.classes,
                      rotated_bbox=args.rotated_bbox,
                      config=config)
        model.initialize(args.fine_tune)
        if verbose: print(model)

    elif ext in ('.pth', '.torch'):
        if verbose:
            print('Loading model from {}...'.format(
                os.path.basename(args.model)))

        # Load in export mode when evaluating.
        exporting = args.command == 'eval'

        model, state = Model.load(filename=args.model,
                                  rotated_bbox=args.rotated_bbox,
                                  config=config,
                                  exporting=exporting)
        if verbose: print(model)

    elif args.command == 'infer' and ext in ['.engine', '.plan']:
        model = None

    else:
        raise RuntimeError('Invalid model format "{}"!'.format(ext))

    state['path'] = model_name  # unlike Example #1, the extension is stripped here
    return model, state
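All of these variants assume a `Model.load` classmethod that returns a `(model, state)` pair. A minimal sketch of that contract in plain PyTorch, assuming a hypothetical checkpoint layout (not the project's actual format):

import torch
import torch.nn as nn

class Model(nn.Module):
    """Sketch of the load/save contract the examples above rely on."""

    @classmethod
    def load(cls, filename):
        # Hypothetical layout: weights plus bookkeeping state side by side.
        checkpoint = torch.load(filename, map_location='cpu')
        model = cls()
        model.load_state_dict(checkpoint['state_dict'])
        return model, checkpoint.get('state', {})

    def save(self, state, filename):
        torch.save({'state_dict': self.state_dict(), 'state': state},
                   filename)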
Example #5
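# Excerpt: `ort_sess`, `ort_inputs`, `to_numpy`, `flatten`,
# `update_flatten_list` and `torch_inference` are defined earlier
# in the original script.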
    ort_outs = ort_sess.run(None, ort_inputs)
    if outputs_flatten is not None:
        print("== Checking model output ==")
        # Compare each PyTorch output with the corresponding ORT output.
        for i, output in enumerate(outputs_flatten):
            np.testing.assert_allclose(to_numpy(output),
                                       ort_outs[i],
                                       rtol=1e-03,
                                       atol=1e-05)
    print("== Done ==")


from PIL import Image
from torchvision import transforms

# `Model` comes from the retinanet-examples project.
# Download pretrained model from:
# https://github.com/NVIDIA/retinanet-examples/releases/tag/19.04
model, state = Model.load('retinanet_rn101fpn/retinanet_rn101fpn.pth')
model.eval()
model.exporting = True
input_image = Image.open(filename)  # `filename`: test image path defined earlier
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)
output = torch_inference(model, input_tensor)

# Test exported model with TensorProto data saved in files
inputs_flatten = flatten(input_tensor.detach().cpu().numpy())
inputs_flatten = update_flatten_list(inputs_flatten, [])
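The excerpt stops before the ONNX Runtime session is created. A minimal sketch of the missing export-and-run step, assuming a hypothetical 'model.onnx' output path and using only standard torch.onnx / onnxruntime calls:

import onnxruntime as ort
import torch

# Export the model (hypothetical output path and opset).
torch.onnx.export(model, input_tensor, 'model.onnx', opset_version=11)

# Create a session and feed the flattened numpy inputs by input name.
ort_sess = ort.InferenceSession('model.onnx',
                                providers=['CPUExecutionProvider'])
ort_inputs = {
    sess_input.name: inp
    for sess_input, inp in zip(ort_sess.get_inputs(), inputs_flatten)
}
ort_outs = ort_sess.run(None, ort_inputs)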