示例#1
0
文件: api.py 项目: hypothermic/RAPiD
    def __init__(self, model_name='', weights_path=None, model=None, **kwargs):
        """Build a detector and load its weights, or wrap a pre-built model.

        Args:
            model_name: Architecture to build when ``model`` is not given.
                Only ``'rapid'`` is supported (built with the ``'dark53'``
                backbone).
            weights_path: Path to a checkpoint whose ``'model'`` entry is the
                state dict to load. Required when ``model`` is not given.
            model: An already-constructed model. When given, it is stored
                as-is and weight loading, ``eval()`` and device placement are
                all skipped.
            **kwargs: ``conf_thres`` and ``input_size`` are stored on the
                instance (default ``None``); ``use_cuda`` (default ``True``)
                selects whether the built model is moved to the GPU.

        Raises:
            NotImplementedError: if ``model_name`` is not supported.
            ValueError: if ``weights_path`` is missing when building a model.
            RuntimeError: if ``use_cuda`` is requested but CUDA is unavailable.
        """
        # post-processing settings
        self.conf_thres = kwargs.get('conf_thres', None)
        self.input_size = kwargs.get('input_size', None)

        # `is not None` rather than truthiness: container modules such as an
        # empty nn.Sequential define __len__ and would otherwise be ignored.
        if model is not None:
            # A pre-built model is used verbatim; the caller owns its mode
            # and device placement.
            self.model = model
            return

        # Validate everything up front so we fail before the expensive
        # model construction and checkpoint load.
        if model_name != 'rapid':
            raise NotImplementedError(f'Unsupported model name: {model_name!r}')
        if weights_path is None:
            raise ValueError('weights_path is required when model is not given')
        use_cuda = kwargs.get('use_cuda', True)
        if use_cuda and not torch.cuda.is_available():
            # Explicit check instead of `assert`, which is stripped under -O.
            raise RuntimeError('use_cuda=True but CUDA is not available')

        from models.rapid import RAPiD
        model = RAPiD(backbone='dark53')
        total_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
        print(f'Successfully initialized model {model_name}.',
              'Total number of trainable parameters:', total_params)

        # Load tensors onto the CPU first so checkpoints saved from a GPU
        # also load on CPU-only machines; the model is moved below if needed.
        checkpoint = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        print(f'Successfully loaded weights: {weights_path}')
        model.eval()
        if use_cuda:
            print("Using CUDA...")
            self.model = model.cuda()
        else:
            print("Using CPU instead of CUDA...")
            self.model = model
示例#2
0
                                initial_size,
                                enable_aug,
                                only_person=only_person,
                                debug_mode=args.debug)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_cpu,
                            pin_memory=True,
                            drop_last=False)
    dataiterator = iter(dataloader)

    if args.model == 'rapid_pL1':
        from models.rapid import RAPiD
        model = RAPiD(backbone=args.backbone,
                      img_norm=False,
                      loss_angle='period_L1')
    elif args.model == 'rapid_pL2':
        from models.rapid import RAPiD
        model = RAPiD(backbone=args.backbone,
                      img_norm=False,
                      loss_angle='period_L2')

    model = model.cuda()

    start_iter = -1
    if args.checkpoint:
        print("loading ckpt...", args.checkpoint)
        weights_path = os.path.join('./weights/', args.checkpoint)
        state = torch.load(weights_path)
        model.load_state_dict(state['model'])