Example #1
    def __init__(self, model: str = None):
        log.info(model)
        torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log.info(torch_device)
        if model is None:
            model = "t5"
        self.modelName = model
        # path to all the files that will be used for inference
        self.path = f"./app/api/{model}/"
        self.model_path = self.path + "pytorch_model.bin"
        self.config_path = self.path + "config.json"

        # Select the right model class based on the passed model name. Defaults to t5.
        if model == "t5":
            self.config = T5Config.from_json_file(self.config_path)
            self.model = T5ForConditionalGeneration(self.config)
            self.tokenizer = T5Tokenizer.from_pretrained(self.path)
            self.model.load_state_dict(torch.load(self.model_path, map_location=torch_device))
            self.model.to(torch_device)
            self.model.eval()
        elif model == "google/pegasus-newsroom":
            self.config = PegasusConfig.from_json_file(self.config_path)
            # self.model = PegasusForConditionalGeneration(self.config)
            # self.tokenizer = PegasusTokenizer.from_pretrained(self.path)
            self.model = PegasusForConditionalGeneration.from_pretrained(model).to(torch_device)
            self.tokenizer = PegasusTokenizer.from_pretrained(model)
        elif model == "facebook/bart-large-cnn":
            self.config = BartConfig.from_json_file(self.config_path)
            # self.model = BartForConditionalGeneration(self.config)
            # self.tokenizer = BartTokenizer.from_pretrained(self.path)
            self.model = BartForConditionalGeneration.from_pretrained(model).to(torch_device)
            self.tokenizer = BartTokenizer.from_pretrained(model)
        else:
            raise ValueError(f"Model '{model}' is not supported")

        self.text = str()
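
The wrapper above boils down to one recurring pattern: rebuild the architecture from a local config.json, then load fine-tuned weights from a saved state dict. A minimal, self-contained sketch of that pattern (the ./app/api/t5/ layout is taken from the example; adjust to your own checkpoint directory):

import torch
from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer

path = "./app/api/t5/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the architecture from the JSON config, then load the fine-tuned weights.
config = T5Config.from_json_file(path + "config.json")
model = T5ForConditionalGeneration(config)
model.load_state_dict(torch.load(path + "pytorch_model.bin", map_location=device))
model.to(device)
model.eval()

tokenizer = T5Tokenizer.from_pretrained(path)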
Example #2
import torch
from transformers import T5Config, T5Model, load_tf_weights_in_t5


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = T5Model(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
Example #3
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file,
                                     pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = T5ForConditionalGeneration(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)
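
Examples #2 and #3 differ in two ways worth noting: #3 builds T5ForConditionalGeneration (the encoder-decoder plus LM head) rather than the bare T5Model, and it saves with save_pretrained(), which writes both config.json and pytorch_model.bin into the dump directory. A checkpoint saved that way reloads in a single call; a short sketch with an illustrative path:

from transformers import T5ForConditionalGeneration

# save_pretrained() stored the config alongside the weights, so no separate
# T5Config.from_json_file step is needed when reloading.
model = T5ForConditionalGeneration.from_pretrained("./t5_dump/")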
Example #4
def get_model(tokenizer_len=None):
  if args.mode == 'train' or args.mode == 'test_without_train':
    model = T5ForConditionalGeneration.from_pretrained(
        args.t5_model, cache_dir=args.cache_dir)
    if tokenizer_len is not None:
      model.resize_token_embeddings(tokenizer_len)
  elif args.mode == 'test' or args.mode == 'continue_train':
    model = T5ForConditionalGeneration(
        T5Config.from_json_file(output_config_file))
    model.load_state_dict(torch.load(output_model_file))
  else:
    raise NotImplementedError(
        f'No such mode called {args.mode}, error raised from get_model.')

  if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
  return model.to(device)
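
One caveat this snippet does not show: a model saved while wrapped in torch.nn.DataParallel has "module."-prefixed state-dict keys, which the plain load_state_dict() call in the 'test' branch would reject. A sketch of the matching save side, reusing output_model_file from the snippet:

import torch

# Unwrap DataParallel before saving so the state-dict keys match a bare
# T5ForConditionalGeneration on reload.
to_save = model.module if hasattr(model, "module") else model
torch.save(to_save.state_dict(), output_model_file)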
Example #5
    def custom_init(self):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = T5Tokenizer.from_pretrained(
            "/usr/src/WHOA-FAQ-Answer-Project/WHO-FAQ-Search-Engine/variation_generation/models/"
        )
        config = T5Config.from_json_file(
            '/usr/src/WHOA-FAQ-Answer-Project/WHO-FAQ-Search-Engine/variation_generation/T5config.json'
        )

        # TODO : Add model weight download
        # self.model = torch.load(path, map_location=self.device)
        self.model = T5ForConditionalGeneration.from_pretrained(
            self.path, from_tf=True, config=config)
        self.model.to(self.device)
        self.model.eval()

        self.initialised = True
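
Because from_tf=True requires TensorFlow to be installed at run time, a common alternative (not part of this example) is to convert the TF checkpoint to PyTorch once, offline, and serve the PyTorch copy. A minimal sketch with illustrative paths:

from transformers import T5Config, T5ForConditionalGeneration

# One-time conversion: load the TF checkpoint, then re-save as PyTorch files.
config = T5Config.from_json_file("T5config.json")
model = T5ForConditionalGeneration.from_pretrained(
    "./tf_checkpoint_dir/", from_tf=True, config=config)
model.save_pretrained("./pytorch_dump/")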
Example #6
    def __init__(self, model: str = None, service: str = "summ"):
        if model is None:
            model = "t5"

        # path to all the files that will be used for inference
        self.path = f"./{service}/{model}/"
        self.model_path = self.path + "model.bin"
        self.config_path = self.path + "config.json"

        # Select the right model class based on the passed model name. Defaults to t5.
        if model == "t5":
            self.config = T5Config.from_json_file(self.config_path)
            self.model = T5ForConditionalGeneration(self.config)
            self.tokenizer = T5Tokenizer.from_pretrained(self.path)
        else:
            raise ValueError(f"Model '{model}' is not supported")

        self.model.load_state_dict(
            torch.load(self.model_path, map_location=device))
        self.model.eval()

        self.text = str()
Example #7
            "WARNING: e2e is meant to generate questions by context. The ouput of the script will be a csv instead of a json."
        )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)

    model_created = False
    print("Loading model and tokenizer...", end="", flush=True)
    if args.checkpoint is not None:
        model_created = True
        if args.bart:
            config = BartConfig.from_json_file(args.checkpoint +
                                               "/config.json")
            model = BartForConditionalGeneration.from_pretrained(
                args.checkpoint + "/pytorch_model.bin", config=config)
        elif args.t5:
            config = T5Config.from_json_file(args.checkpoint + "/config.json")
            model = T5ForConditionalGeneration.from_pretrained(
                args.checkpoint + "/pytorch_model.bin", config=config)
        else:
            config = EncoderDecoderConfig.from_json_file(args.checkpoint +
                                                         "/config.json")
            model = EncoderDecoderModel.from_pretrained(args.checkpoint +
                                                        "/pytorch_model.bin",
                                                        config=config)
        model_name = args.checkpoint

    if args.bart:
        if args.checkpoint is None:
            model_name = "WikinewsSum/bart-large-multi-fr-wiki-news" if args.model_name == "" else args.model_name
        tokenizer = BartTokenizer.from_pretrained(
            args.tokenizer