def load(self, model, pretrain_file):
        """Load pretrained transformer weights into *model*.

        A ``.ckpt`` path is treated as a TensorFlow checkpoint and handed to
        the project's ``checkpoint`` helper; anything else is assumed to be a
        PyTorch save file, from which only the ``transformer.*`` entries are
        loaded (prefix stripped) into ``model.transformer``.
        """
        print('Loading the pretrained model from', pretrain_file)

        if pretrain_file.endswith('.ckpt'):  # TensorFlow checkpoint
            checkpoint.load_model(model.transformer, pretrain_file)
        else:
            # Keep only 'transformer.*' entries; drop the 12-char
            # 'transformer.' prefix so keys match the sub-module's state dict.
            full_state = torch.load(pretrain_file)
            transformer_state = {}
            for name, tensor in full_state.items():
                if name.startswith('transformer'):
                    transformer_state[name[12:]] = tensor
            model.transformer.load_state_dict(transformer_state)
# Example #2
# 0
    def load(self, model_file, pretrain_file):
        """ load saved model or pretrained transformer (a part of model) """
        if model_file:
            # A fine-tuned model snapshot takes priority; strict=False allows
            # partial state dicts (e.g. missing task heads).
            print('Loading the model from', model_file)
            saved_state = torch.load(model_file, map_location=self.device)
            self.model.load_state_dict(saved_state, strict=False)

        elif pretrain_file: # use pretrained transformer
            print('Loading the pretrained model from', pretrain_file)
            if pretrain_file.endswith('.ckpt'): # checkpoint file in tensorflow
                checkpoint.load_model(self.model.transformer, pretrain_file)
            elif pretrain_file.endswith('.pt'): # pretrain model file in pytorch
                # Keep only 'transformer.*' keys and strip the 12-char
                # 'transformer.' prefix before loading the sub-module.
                transformer_state = {}
                for name, tensor in torch.load(pretrain_file).items():
                    if name.startswith('transformer'):
                        transformer_state[name[12:]] = tensor
                self.model.transformer.load_state_dict(transformer_state)
# Example #3
# 0
    def load(self, model_file, pretrain_file):
        """ load saved model or pretrained transformer (a part of model) """
        if pretrain_file:  # use pretrained transformer
            if pretrain_file.endswith('.ckpt'):
                # TensorFlow checkpoint: delegate to the project helper.
                checkpoint.load_model(self.model.transformer, pretrain_file)
            elif pretrain_file.endswith('.pt'):
                # PyTorch save file: load only the transformer part, stripping
                # the 12-char 'transformer.' prefix from each key.
                full_state = torch.load(pretrain_file)
                transformer_state = {
                    name[12:]: tensor
                    for name, tensor in full_state.items()
                    if name.startswith('transformer')
                }
                self.model.transformer.load_state_dict(transformer_state)

        if model_file:
            self.model.load_state_dict(torch.load(model_file))
# Example #4
# 0
def main():
    """Entry point: run detection over an input video and evaluate it."""
    args = parse_args(parameters, "video detection",
                      "video evaluation parameters")
    print(args)

    model, encoder, model_args = load_model(args.model)
    print("model parameters:")
    print(model_args)

    class_names = model_args.dataset.classes

    frames, info = cv.video_capture(args.input)
    print("Input video", info)

    # Optional rescaling of every frame; default is the native size.
    factor = args.scale or 1
    target_size = (int(info.size[0] * factor), int(info.size[1] * factor))

    evaluate_image = initialise(args.model, model, encoder, target_size,
                                args.backend)
    evaluate_video(frames, evaluate_image, target_size, args,
                   classes=class_names, fps=info.fps)
    def load(self, model, pretrain_file):  # pre-trained model loading

        print('Loading the pretrained model from', pretrain_file)

        if pretrain_file.endswith('.ckpt'):  # checkpoint file in tensorflow
            checkpoint.load_model(model.transformer, pretrain_file)