Esempio n. 1
0
def create(name, pretrained, channels, classes):
    """Creates a specified YOLOv5 model

    Builds the model from ``models/<name>.yaml`` located next to this file. If
    ``pretrained`` is set, it also loads matching weights from ``<name>.pt``,
    appends an NMS module and switches the model to eval mode.

    Arguments:
        name (str): name of model, i.e. 'yolov5s'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes

    Returns:
        pytorch model

    Raises:
        Exception: any error raised while building the model or loading the
            checkpoint, re-raised (chained with ``from``) with a hint about a
            possibly stale cache.
    """
    config = os.path.join(os.path.dirname(__file__), 'models', '%s.yaml' % name)  # model.yaml path
    try:
        model = Model(config, channels, classes)
        if pretrained:
            ckpt = '%s.pt' % name  # checkpoint filename
            attempt_download(ckpt)  # download if not found locally
            # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
            state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict()  # to FP32
            # Keep only tensors whose name exists in this model and whose shape matches.
            # Build the model's state_dict once (not once per key), and skip unknown keys
            # instead of raising KeyError, which fits the non-strict load below.
            model_sd = model.state_dict()
            state_dict = {k: v for k, v in state_dict.items()
                          if k in model_sd and model_sd[k].shape == v.shape}  # filter
            model.load_state_dict(state_dict, strict=False)  # load

            m = NMS()
            m.f = -1  # from
            m.i = model.model[-1].i + 1  # index
            model.model.add_module(name='%s' % m.i, module=m)  # add NMS
            model.eval()
        return model

    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
        s = 'Cache may be out of date, deleting cache and retrying may solve this. See %s for help.' % help_url
        raise Exception(s) from e
Esempio n. 2
0
 def add_nms(self):  # append an NMS module if the model does not already end with one
     """Append an NMS module as the last layer of ``self.model`` if it is missing.

     The new module gets ``f = -1`` and is registered under the index following
     the current last layer's ``i``. After appending, the model is switched to
     eval mode, consistent with the other code paths that attach NMS.

     Returns:
         self, so calls can be chained.
     """
     if type(self.model[-1]) is not NMS:  # if missing NMS
         print('Adding NMS module... ')
         m = NMS()  # module
         m.f = -1  # from
         m.i = self.model[-1].i + 1  # index
         self.model.add_module(name='%s' % m.i, module=m)  # add
         self.eval()
     return self
Esempio n. 3
0
 def nms(self, mode=True):  # add or remove NMS module
     """Attach or detach a trailing NMS module on ``self.model``.

     Arguments:
         mode: truthy to ensure an NMS module is the last layer (the model is
             then put in eval mode); falsy to drop a trailing NMS module.
             Nothing changes when the model is already in the requested state.

     Returns:
         self, so calls can be chained.
     """
     last_is_nms = type(self.model[-1]) is NMS
     if mode:
         if not last_is_nms:
             print('Adding NMS... ')
             nms_layer = NMS()
             nms_layer.f = -1  # input comes from the previous layer
             nms_layer.i = self.model[-1].i + 1  # next layer index
             self.model.add_module(name=str(nms_layer.i), module=nms_layer)
             self.eval()
     elif last_is_nms:
         print('Removing NMS... ')
         self.model = self.model[:-1]  # slice off the trailing NMS layer
     return self