예제 #1
0
 def _load_t5_model(self, model_name_or_path, config, cache_dir):
     """Load only the encoder half of a T5 checkpoint into ``self.auto_model``.

     The checkpoint also contains decoder weights that an encoder-only
     model does not use, so they are registered on the class-level ignore
     list to silence the "unexpected keys" warning during loading.
     """
     from transformers import T5EncoderModel

     # Ignore every decoder.* weight present in the full seq2seq checkpoint.
     T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
     self.auto_model = T5EncoderModel.from_pretrained(
         model_name_or_path, config=config, cache_dir=cache_dir)
예제 #2
0
    def __init__(self, hparams):
        """Assemble the VQA model: T5 decoder, question encoder, CNN image embedder.

        Args:
            hparams: namespace-like hyper-parameter container; fields read
                here: t5_model, same_enc, concat_only, hidden_dim, seq_len,
                img_h, img_w, gpus.
        """
        super().__init__()

        # Keep the raw hyper-parameter container around for later use.
        self.hparams = hparams

        # Tokenizer used to decode generated answer sentences.
        self.tokenizer = T5Tokenizer.from_pretrained(self.hparams.t5_model)

        # Full seq2seq model; its decoder consumes the image embedding
        # combined with the question encoder's last hidden state.
        self.decoder = T5ForConditionalGeneration.from_pretrained(
            self.hparams.t5_model)

        # Question encoder: either share the decoder's own encoder, or load
        # a standalone encoder-only copy of the same checkpoint.
        self.sentence_encoder = (
            self.decoder.get_encoder()
            if self.hparams.same_enc
            else T5EncoderModel.from_pretrained(self.hparams.t5_model)
        )

        # Adapter mapping encoder hidden size onto the sequence axis used to
        # combine image features with the encoder output.
        # NOTE: hidden_dim must be set manually to match the T5 variant.
        if not self.hparams.concat_only:
            self.adapter = nn.Linear(self.hparams.hidden_dim,
                                     self.hparams.seq_len)

        # CNN stack lifting 3-channel images up to the transformer decoder's
        # channel width (d_model).
        self.CNNEmbedder = nn.Sequential(
            ConvBlock(3, 16),
            ConvBlock(16, 64),
            ConvBlock(64, 256),
            ConvBlock(256, self.decoder.config.d_model),
        )

        self.img_shape = (hparams.img_h, hparams.img_w)
        # Synchronize logged metrics across devices only on multi-GPU runs.
        self.sync_dist = self.hparams.gpus > 1
예제 #3
0
    def get_model(self) -> Union[T5Model, T5EncoderModel]:
        """Instantiate the stored checkpoint.

        Returns the full ``T5Model`` when ``self._decoder`` is set, otherwise
        an encoder-only ``T5EncoderModel``; loads weights directly in fp16
        when ``self._half_precision_model`` is set.
        """
        model_cls = T5Model if self._decoder else T5EncoderModel
        # Request half-precision weights at load time when configured.
        extra_kwargs = (
            {"torch_dtype": torch.float16} if self._half_precision_model else {}
        )
        return model_cls.from_pretrained(self._model_directory, **extra_kwargs)
예제 #4
0
 def get_model(self) -> Union[T5Model, T5EncoderModel]:
     """Load the checkpoint, encoder-only unless a decoder was requested.

     When ``self._half_precision_model`` is set, the weights are cast to
     fp16 afterwards, halving the memory footprint.
     """
     model_cls = T5Model if self._decoder else T5EncoderModel
     model = model_cls.from_pretrained(self._model_directory)
     if self._half_precision_model:
         # Compute in half precision, saving us half the memory.
         model = model.half()
     return model
예제 #5
0
 def __init__(self, hparams):
     """Set up tokenizer, T5 decoder and question encoder from *hparams*.

     Args:
         hparams: namespace-like config; fields read here: t5_model,
             same_enc, gpus.
     """
     super().__init__()

     # Hyper-parameter container, stored for later access.
     self.hparams = hparams

     # Tokenizer used to decode generated sentences.
     self.tokenizer = T5Tokenizer.from_pretrained(self.hparams.t5_model)

     # Seq2seq model whose decoder consumes the image embedding combined
     # with the encoder's last hidden state.
     self.decoder = T5ForConditionalGeneration.from_pretrained(
         self.hparams.t5_model)

     # Question encoder: reuse the decoder's own encoder, or load a fresh
     # encoder-only copy of the same checkpoint.
     self.sentence_encoder = (
         self.decoder.get_encoder()
         if self.hparams.same_enc
         else T5EncoderModel.from_pretrained(self.hparams.t5_model)
     )

     # Sync logged metrics across devices only when using multiple GPUs.
     self.sync_dist = self.hparams.gpus > 1
def similarity_oracle(device=1, model_name='t5-small', max_length=1024,
                      output_path='val_sentence_embeddings.pt'):
    """Precompute sentence embeddings for the CNN/DailyMail validation split.

    For each validation article the sentences are shuffled, joined with the
    T5 ``</s>`` separator token, encoded with a T5 encoder, and the
    per-sentence embeddings are saved to *output_path* via ``torch.save``.

    All parameters default to the previously hard-coded values, so calling
    ``similarity_oracle()`` behaves exactly as before.

    Args:
        device: CUDA device index to run on.
        model_name: HuggingFace identifier for both model and tokenizer.
        max_length: tokenizer truncation/padding length per article.
        output_path: file the embedding dict is written to.
    """
    torch.cuda.set_device(device)
    model = T5EncoderModel.from_pretrained(model_name).cuda()
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    dataset = load_dataset('cnn_dailymail', '3.0.0')
    val_dataset = dataset['validation']
    ids = val_dataset['id']
    sentences = _get_val_article_sentences()

    # Shuffle sentence order per article so the oracle cannot exploit
    # positional bias.
    for i in ids:
        np.random.shuffle(sentences[i])

    sep_token = "</s>"
    sep_token_id = 1  # T5's </s> token id
    inputs = [f" {sep_token} ".join(sentences[i]) for i in ids]

    model_input = tokenizer(inputs, max_length=max_length,
                            padding="max_length", truncation=True)
    # Map each input token position to the index of the sentence it belongs to.
    sentence_indicator = _create_sentence_indicator(
        model_input['input_ids'], tokenizer, sep_token_id)

    d = _create_sentence_embeddings(model, ids, model_input, sentence_indicator)
    torch.save(d, output_path)
예제 #7
0
def get_model(model_type, num_layers, all_layers=None):
    """Load a pretrained transformer and truncate it to *num_layers* layers.

    Args:
        model_type: HuggingFace model identifier, or a "scibert..." name
            resolved through ``cache_scibert``.
        num_layers: number of layers to keep when *all_layers* is falsy.
        all_layers: when truthy, keep every layer and instead configure the
            model to output all hidden states.

    Returns:
        The (possibly truncated) model, in eval mode.

    Raises:
        AssertionError: if *num_layers* is outside the model's valid range.
        ValueError: if the model's layer container is not recognized.
    """
    print(f"Initializing model from : {model_type}")
    if model_type.startswith("scibert"):
        model = AutoModel.from_pretrained(cache_scibert(model_type))
    elif "t5" in model_type:
        from transformers import T5EncoderModel

        # Encoder-only variant: the decoder is not needed for embedding.
        model = T5EncoderModel.from_pretrained(model_type)
    else:
        model = AutoModel.from_pretrained(model_type)
    model.eval()

    # For seq2seq models loaded via AutoModel, keep only the encoder half.
    if hasattr(model, "decoder") and hasattr(model, "encoder"):
        model = model.encoder

    # drop unused layers
    # Each branch below locates the architecture-specific layer container,
    # validates num_layers against it, and truncates in place.
    if not all_layers:
        if hasattr(model, "n_layers"):  # xlm
            assert (
                0 <= num_layers <= model.n_layers
            ), f"Invalid num_layers: num_layers should be between 0 and {model.n_layers} for {model_type}"
            model.n_layers = num_layers
        elif hasattr(model, "layer"):  # xlnet
            assert (
                0 <= num_layers <= len(model.layer)
            ), f"Invalid num_layers: num_layers should be between 0 and {len(model.layer)} for {model_type}"
            model.layer = torch.nn.ModuleList(
                [layer for layer in model.layer[:num_layers]])
        elif hasattr(model, "encoder"):  # albert
            if hasattr(model.encoder, "albert_layer_groups"):
                # ALBERT shares weights across groups; truncation happens by
                # lowering the configured number of hidden layers.
                assert (
                    0 <= num_layers <= model.encoder.config.num_hidden_layers
                ), f"Invalid num_layers: num_layers should be between 0 and {model.encoder.config.num_hidden_layers} for {model_type}"
                model.encoder.config.num_hidden_layers = num_layers
            elif hasattr(model.encoder, "block"):  # t5
                assert (
                    0 <= num_layers <= len(model.encoder.block)
                ), f"Invalid num_layers: num_layers should be between 0 and {len(model.encoder.block)} for {model_type}"
                model.encoder.block = torch.nn.ModuleList(
                    [layer for layer in model.encoder.block[:num_layers]])
            else:  # bert, roberta
                assert (
                    0 <= num_layers <= len(model.encoder.layer)
                ), f"Invalid num_layers: num_layers should be between 0 and {len(model.encoder.layer)} for {model_type}"
                model.encoder.layer = torch.nn.ModuleList(
                    [layer for layer in model.encoder.layer[:num_layers]])
        elif hasattr(model, "transformer"):  # bert, roberta
            assert (
                0 <= num_layers <= len(model.transformer.layer)
            ), f"Invalid num_layers: num_layers should be between 0 and {len(model.transformer.layer)} for {model_type}"
            model.transformer.layer = torch.nn.ModuleList(
                [layer for layer in model.transformer.layer[:num_layers]])
        elif hasattr(model, "layers"):  # bart
            assert (
                0 <= num_layers <= len(model.layers)
            ), f"Invalid num_layers: num_layers should be between 0 and {len(model.layers)} for {model_type}"
            model.layers = torch.nn.ModuleList(
                [layer for layer in model.layers[:num_layers]])
        else:
            raise ValueError("Not supported")
    else:
        # Keep all layers, but request every hidden state so callers can
        # score against any layer.
        if hasattr(model, "output_hidden_states"):
            model.output_hidden_states = True
        elif hasattr(model, "encoder"):
            model.encoder.output_hidden_states = True
        elif hasattr(model, "transformer"):
            model.transformer.output_hidden_states = True
        # else:
        #     raise ValueError(f"Not supported model architecture: {model_type}")

    return model
예제 #8
0
def prepare_model(config, bert_model_name_or_path=None):
    """Build the text-classification model described by *config*.

    Depending on ``config['emb_class']`` this either builds a GloVe-based
    model or wraps a pretrained transformer (kobart/gpt/t5/other) with a
    classification head, then optionally restores a checkpoint and prepares
    the model for quantization-aware training (eager or FX graph mode).

    Args:
        config: configuration dict; must contain 'args', 'emb_class' and
            'enc_class' (plus 'use_kobart' for the bart path).
        bert_model_name_or_path: optional override for the transformer
            checkpoint path from ``args.bert_model_name_or_path``.

    Returns:
        The constructed (and possibly QAT-prepared) model.

    Raises:
        ValueError: if ``config['enc_class']`` is unknown for GloVe models.
    """
    args = config['args']
    emb_non_trainable = not args.embedding_trainable
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    # prepare model
    if config['emb_class'] == 'glove':
        # Exclusive dispatch on the encoder class. The original chained
        # plain `if`s, which left `model` unbound (NameError downstream)
        # for an unrecognized enc_class; fail fast instead.
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        elif config['enc_class'] == 'cnn':
            model = TextGloveCNN(config,
                                 args.embedding_path,
                                 label_size,
                                 emb_non_trainable=emb_non_trainable)
        elif config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
        elif config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
        else:
            raise ValueError(
                "unknown enc_class for glove: {}".format(config['enc_class']))
    else:
        model_name_or_path = args.bert_model_name_or_path
        if bert_model_name_or_path:
            model_name_or_path = bert_model_name_or_path

        if config['emb_class'] == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
        elif config['emb_class'] in ['gpt']:
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))
            # 3 new tokens added above -> grow embedding matrix to match.
            bert_model.resize_token_embeddings(len(bert_tokenizer))
        elif config['emb_class'] in ['t5']:
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            # Encoder-only T5; the decoder is unused for classification.
            bert_model = T5EncoderModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))
        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))

        bert_config = bert_model.config
        # bert model reduction
        reduce_bert_model(config, bert_model, bert_config)

        # Pick the classification head for the encoder class
        # (TextBertCNN is the default).
        if config['enc_class'] == 'cls':
            ModelClass = TextBertCLS
        elif config['enc_class'] == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN
        else:
            ModelClass = TextBertCNN

        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           label_size,
                           feature_based=args.bert_use_feature_based,
                           finetune_last=args.bert_use_finetune_last)

    if args.restore_path:
        checkpoint = load_checkpoint(args.restore_path)
        model.load_state_dict(checkpoint)
    if args.enable_qat:
        # Eager-mode quantization-aware training.
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        # fuse if applicable, e.g.:
        # model = torch.quantization.fuse_modules(model, [['']])
        model = torch.quantization.prepare_qat(model)
    if args.enable_qat_fx:
        # FX-graph-mode quantization-aware training.
        import torch.quantization.quantize_fx as quantize_fx
        model.train()
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)

    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[model prepared]")
    return model