def __init__(self, model: str = None):
    log.info(model)
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(torch_device)
    if model is None:
        model = "t5"
    self.modelName = model
    # Path to all the files that will be used for inference
    self.path = f"./app/api/{model}/"
    self.model_path = self.path + "pytorch_model.bin"
    self.config_path = self.path + "config.json"
    # Select the correct model based on the passed model input. Default: t5
    if model == "t5":
        # Build T5 from the local config and load the fine-tuned weights from disk
        self.config = T5Config.from_json_file(self.config_path)
        self.model = T5ForConditionalGeneration(self.config)
        self.tokenizer = T5Tokenizer.from_pretrained(self.path)
        self.model.eval()
        self.model.load_state_dict(torch.load(self.model_path, map_location=torch_device))
    elif model == "google/pegasus-newsroom":
        # Pegasus weights and tokenizer are pulled from the Hugging Face hub
        self.config = PegasusConfig.from_json_file(self.config_path)
        self.model = PegasusForConditionalGeneration.from_pretrained(model).to(torch_device)
        self.tokenizer = PegasusTokenizer.from_pretrained(model)
    elif model == "facebook/bart-large-cnn":
        # BART weights and tokenizer are pulled from the Hugging Face hub
        self.config = BartConfig.from_json_file(self.config_path)
        self.model = BartForConditionalGeneration.from_pretrained(model).to(torch_device)
        self.tokenizer = BartTokenizer.from_pretrained(model)
    else:
        raise Exception("This model is not supported")
    self.text = str()
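# Usage sketch (hypothetical): the __init__ above is assumed to sit on a summarization
# wrapper class, called Summarizer here purely for illustration, with ./app/api/t5/
# assumed to contain pytorch_model.bin, config.json and the T5 tokenizer files.
summarizer = Summarizer()                        # falls back to the local T5 checkpoint
pegasus = Summarizer("google/pegasus-newsroom")  # loads Pegasus weights from the Hugging Face hub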
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = T5Model(config)

    # Load weights from TF checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save PyTorch model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
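# Usage sketch with hypothetical paths: this variant wraps a bare T5Model and dumps a raw
# state_dict via torch.save, so the target should be a single .bin file (the variant below
# writes a whole from_pretrained-style directory instead).
convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="./tf_checkpoint/model.ckpt",    # hypothetical TF checkpoint prefix
    config_file="./tf_checkpoint/config.json",          # hypothetical T5 config
    pytorch_dump_path="./t5_pytorch/pytorch_model.bin", # single weights file written by torch.save
)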
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = T5ForConditionalGeneration(config)

    # Load weights from TF checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save PyTorch model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)
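# Usage sketch with hypothetical paths: this variant calls save_pretrained, so the converted
# model can be reloaded directly from the output directory afterwards.
convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="./tf_checkpoint/model.ckpt",  # hypothetical TF checkpoint prefix
    config_file="./tf_checkpoint/config.json",        # hypothetical T5 config
    pytorch_dump_path="./t5_pytorch",                 # output directory for save_pretrained()
)
model = T5ForConditionalGeneration.from_pretrained("./t5_pytorch")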
def get_model(tokenizer_len=None):
    if args.mode == 'train' or args.mode == 'test_without_train':
        # Fresh (or cached) pretrained weights from the hub
        model = T5ForConditionalGeneration.from_pretrained(
            args.t5_model, cache_dir=args.cache_dir)
        if tokenizer_len is not None:
            model.resize_token_embeddings(tokenizer_len)
    elif args.mode == 'test' or args.mode == 'continue_train':
        # Rebuild the architecture from the saved config and load local weights
        model = T5ForConditionalGeneration(
            T5Config.from_json_file(output_config_file))
        model.load_state_dict(torch.load(output_model_file))
    else:
        raise NotImplementedError(
            f'No such mode called {args.mode}, error raised from get_model.')
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    return model.to(device)
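# Usage sketch: get_model() above relies on module-level names (args, device, and, for the
# test/continue_train branches, output_config_file and output_model_file). The values below
# are hypothetical placeholders for the 'train' branch only.
import argparse

args = argparse.Namespace(mode='train', t5_model='t5-base', cache_dir='./cache')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model()  # pass tokenizer_len=len(tokenizer) if special tokens were added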
def custom_init(self):
    self.device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu')
    self.tokenizer = T5Tokenizer.from_pretrained(
        "/usr/src/WHOA-FAQ-Answer-Project/WHO-FAQ-Search-Engine/variation_generation/models/"
    )
    config = T5Config.from_json_file(
        '/usr/src/WHOA-FAQ-Answer-Project/WHO-FAQ-Search-Engine/variation_generation/T5config.json'
    )
    # TODO: Add model weight download
    # self.model = torch.load(path, map_location=self.device)
    self.model = T5ForConditionalGeneration.from_pretrained(
        self.path, from_tf=True, config=config)
    self.model.to(self.device)
    self.model.eval()
    self.max_length = self.max_length
    self.num_variations = self.num_variations
    self.initialised = True
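# Usage sketch: custom_init() above expects self.path, self.max_length and self.num_variations
# to have been set beforehand (e.g. in the class constructor); the class name
# VariationGenerator below is purely illustrative.
generator = VariationGenerator()
generator.custom_init()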
def __init__(self, model: str = None, service: str = "summ"):
    if model is None:
        model = "t5"
    # Path to all the files that will be used for inference
    self.path = f"./{service}/{model}/"
    self.model_path = self.path + "model.bin"
    self.config_path = self.path + "config.json"
    # Select the correct model based on the passed model input. Default: t5
    if model == "t5":
        self.config = T5Config.from_json_file(self.config_path)
        self.model = T5ForConditionalGeneration(self.config)
        self.tokenizer = T5Tokenizer.from_pretrained(self.path)
    else:
        raise Exception("This model is not supported")
    self.model.eval()
    self.model.load_state_dict(
        torch.load(self.model_path, map_location=device))
    self.text = str()
"WARNING: e2e is meant to generate questions by context. The ouput of the script will be a csv instead of a json." ) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print("Device:", device) model_created = False print("Loading model and tokenizer...", end="", flush=True) if args.checkpoint != None: model_created = True if args.bart: config = BartConfig.from_json_file(args.checkpoint + "/config.json") model = BartForConditionalGeneration.from_pretrained( args.checkpoint + "/pytorch_model.bin", config=config) if args.t5: config = T5Config.from_json_file(args.checkpoint + "/config.json") model = T5ForConditionalGeneration.from_pretrained( args.checkpoint + "/pytorch_model.bin", config=config) elif not args.bart and not args.t5: config = EncoderDecoderConfig.from_json_file(args.checkpoint + "/config.json") model = EncoderDecoderModel.from_pretrained(args.checkpoint + "/pytorch_model.bin", config=config) model_name = args.checkpoint if args.bart: if args.checkpoint == None: model_name = "WikinewsSum/bart-large-multi-fr-wiki-news" if args.model_name == "" else args.model_name tokenizer = BartTokenizer.from_pretrained( args.tokenizer