    class TestBertEmbedding(unittest.TestCase):
        def init_data(self, use_cuda: bool):
            test_device = torch.device('cuda:0') if use_cuda else \
                torch.device('cpu:0')

            torch.set_grad_enabled(False)
            cfg = BertConfig()
            self.torch_embedding = BertEmbeddings(cfg)

            self.torch_embedding.eval()

            if use_cuda:
                self.torch_embedding.to(test_device)

            self.turbo_embedding = turbo_transformers.BertEmbeddings.from_torch(
                self.torch_embedding)

            input_ids = torch.randint(low=0,
                                      high=cfg.vocab_size - 1,
                                      size=(batch_size, seq_length),
                                      dtype=torch.long,
                                      device=test_device)
            position_ids = torch.arange(seq_length,
                                        dtype=torch.long,
                                        device=input_ids.device)

            position_ids = position_ids.repeat(batch_size, 1)
            token_type_ids = torch.zeros_like(input_ids, dtype=torch.long)

            return input_ids, position_ids, token_type_ids

        def check_torch_and_turbo(self, use_cuda):
            input_ids, position_ids, token_type_ids = self.init_data(use_cuda)

            device = "GPU" if use_cuda else "CPU"
            num_iter = 100
            torch_model = lambda: self.torch_embedding(
                input_ids, token_type_ids, position_ids)
            torch_result, torch_qps, torch_time = test_helper.run_model(
                torch_model, use_cuda, num_iter)
            print(f"BertEmbeddings \"({batch_size},{seq_length:03})\" ",
                  f"{device} Torch QPS,  {torch_qps}, time, {torch_time}")

            turbo_model = lambda: self.turbo_embedding(input_ids, position_ids,
                                                       token_type_ids)
            turbo_result, turbo_qps, turbo_time = test_helper.run_model(
                turbo_model, use_cuda, num_iter)
            print(f"BertEmbeddings \"({batch_size},{seq_length:03})\" ",
                  f"{device} Turbo QPS,  {turbo_qps}, time, {turbo_time}")

            self.assertTrue(
                torch.max(torch.abs(torch_result - turbo_result)) < 1e-5)

        def test_embedding(self):
            self.check_torch_and_turbo(use_cuda=False)
            if torch.cuda.is_available() and \
                turbo_transformers.config.is_compiled_with_cuda():
                self.check_torch_and_turbo(use_cuda=True)
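This test assumes batch_size and seq_length defined in the enclosing scope and the turbo_transformers test_helper harness, none of which are shown here. A minimal, self-contained sketch of the same BertEmbeddings call, without the Turbo comparison or the timing loop (the values below are arbitrary):

    import torch
    from transformers import BertConfig
    from transformers.models.bert.modeling_bert import BertEmbeddings  # transformers >= 4.x layout

    batch_size, seq_length = 2, 40
    cfg = BertConfig()
    embedding = BertEmbeddings(cfg).eval()
    with torch.no_grad():
        input_ids = torch.randint(0, cfg.vocab_size - 1, (batch_size, seq_length), dtype=torch.long)
        output = embedding(input_ids=input_ids)  # position_ids / token_type_ids default internally
    print(output.shape)  # (batch_size, seq_length, hidden_size)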
Example #3
    def __init__(self, config):
        super(BertImgModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)  # CaptionBertEncoder(config)
        self.pooler = BertPooler(config)

        self.img_dim = config.img_feature_dim
        logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim))
        self.img_feature_type = config.img_feature_type
        if hasattr(config, 'use_img_layernorm'):
            self.use_img_layernorm = config.use_img_layernorm
        else:
            self.use_img_layernorm = None

        if config.img_feature_type == 'dis_code':
            self.code_embeddings = nn.Embedding(config.code_voc, config.code_dim, padding_idx=0)
            self.img_embedding = nn.Linear(config.code_dim, self.config.hidden_size, bias=True)
        elif config.img_feature_type == 'dis_code_t': # transpose
            self.code_embeddings = nn.Embedding(config.code_voc, config.code_dim, padding_idx=0)
            self.img_embedding = nn.Linear(config.code_size, self.config.hidden_size, bias=True)
        elif config.img_feature_type == 'dis_code_scale': # scaled
            self.input_embeddings = nn.Linear(config.code_dim, config.code_size, bias=True)
            self.code_embeddings = nn.Embedding(config.code_voc, config.code_dim, padding_idx=0)
            self.img_embedding = nn.Linear(config.code_dim, self.config.hidden_size, bias=True)
        else:
            self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True)
            self.dropout = nn.Dropout(config.hidden_dropout_prob)
            if self.use_img_layernorm:
                self.LayerNorm = LayerNorm(config.hidden_size, eps=config.img_layer_norm_eps)
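The forward logic that uses these image modules is not part of this snippet; below is a hypothetical helper showing how the default-branch modules above are typically combined (the real BertImgModel.forward also handles the dis_code variants and concatenates the result with the text embeddings):

    def _embed_img_feats(self, img_feats):
        # Default branch: project region features to hidden_size,
        # optionally layer-normalize, then apply dropout.
        img_embedding_output = self.img_embedding(img_feats)
        if self.use_img_layernorm:
            img_embedding_output = self.LayerNorm(img_embedding_output)
        return self.dropout(img_embedding_output)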
Example #4
	def __init__(self,image_feature_dim,num_segments,num_class,fc_dim=1024):
		super(Transformermodule,self).__init__()
		self.image_feature_dim = image_feature_dim
		self.num_segments = num_segments
		
		self.configname = 'bert-base-uncased'
		#self.configname = 'roberta-base'
		
		# Load Bert Model as the transformer
		self.tokenizer = BertTokenizer.from_pretrained(self.configname)
		self.config = BertConfig.from_pretrained(self.configname)
		# The full Bert Model with 12 Layers
		# self.transformer = BertModel.from_pretrained(self.configname, config = self.config)
		
		# Most of the 12 encoder layers are removed; only the first num_layers_to_keep layers (3 below) are kept.
		self.transformer = BertModel.from_pretrained(self.configname, config = self.config)
		self.transformer = self.remove_bert_layers(self.transformer, num_layers_to_keep=3)


		# Project the video embedding to the transformer embedding for processing.
		self.hidden_dim = self.transformer.config.hidden_size

		self.projection_layer = nn.Linear(image_feature_dim,self.hidden_dim)
		self.embedding_fn = BertEmbeddings(self.config)
		self.fc = nn.Sequential(nn.Linear(self.hidden_dim,fc_dim),
								nn.Dropout(0.5),
								nn.Tanh(),
								nn.Linear(fc_dim,num_class))
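remove_bert_layers is a project-specific method that is not shown in this snippet. A minimal sketch of what such a helper plausibly does, written here as a free function and assuming the standard HuggingFace BertModel layout where the transformer blocks live in model.encoder.layer:

    import torch.nn as nn

    def remove_bert_layers(model, num_layers_to_keep=3):
        # Keep only the first num_layers_to_keep transformer blocks and
        # update the config so downstream code sees the reduced depth.
        model.encoder.layer = nn.ModuleList(model.encoder.layer[:num_layers_to_keep])
        model.config.num_hidden_layers = num_layers_to_keep
        return model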
Example #5
    def __init__(self,
                 args,
                 tokenizer: BertTokenizer,
                 object_features_variant=False,
                 positional_embed_variant=False,
                 latent_transformer=False):
        super().__init__()

        self.args = args
        self.tokenizer = tokenizer

        self.image_projection = nn.Sequential(
            nn.Linear(512, 768), nn.BatchNorm1d(768, momentum=0.01))

        config = BertConfig.from_pretrained('bert-base-uncased')
        self.tokenizer = tokenizer
        self.embeddings = BertEmbeddings(config)

        self.text_encoder = BertModel.from_pretrained("bert-base-uncased",
                                                      return_dict=True)
        self.decoder = BertLMHeadModel.from_pretrained(
            'bert-base-uncased',
            is_decoder=True,
            use_cache=True,
            add_cross_attention=True)

        if object_features_variant:
            self.image_transformer = ImageTransformerEncoder(args)

        self.positional_embed = bool(positional_embed_variant)

        self.latent_transformer = latent_transformer
Example #6
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()
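The forward pass that ties these three modules together is inherited from the standard BertModel; a simplified sketch of that flow (attention-mask handling and the many optional arguments omitted, assuming the current transformers interfaces):

    def _sketch_forward(self, input_ids, token_type_ids=None):
        # Embed token ids, run the transformer stack, then pool the first ([CLS]) position.
        embedding_output = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)
        return sequence_output, pooled_output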
Example #7
    def __init__(self, config):
        super(BojoneModel, self).__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.cls = BertPreTrainingHeads(self.embeddings.word_embeddings.weight,
                                        config)
        self.init_weights()
Example #8
    def __init__(self, config, tokenizer, device):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.embeddings = BertEmbeddings(self.config)
        self.corrector = BertEncoder(self.config)
        self.mask_token_id = self.tokenizer.mask_token_id
        self.cls = BertOnlyMLMHead(self.config)
        self._device = device
Example #9
    def __init__(self, config, args):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.MAG = MAG(config, args)

        self.init_weights()
Example #10
    def _build_word_embedding(self):
        self.bert_config = BertConfig.from_pretrained(self.config.bert_model_name)
        if self.config.pretrained_bert:
            bert_model = BertForPreTraining.from_pretrained(self.config.bert_model_name)
            self.word_embedding = bert_model.bert.embeddings
            self.pooler = bert_model.bert.pooler
            self.pooler.apply(self.init_weights)

        else:
            self.pooler = BertPooler(self.bert_config)
            self.word_embedding = BertEmbeddings(self.bert_config)
Example #11
    def __init__(self, config):
        super(BertImgModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = CaptionBertEncoder(config)
        self.pooler = BertPooler(config)

        self.img_dim = config.img_feature_dim
        logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim))

        # self.apply(self.init_weights)
        self.init_weights()
Example #12
    def __init__(self, config: LukeConfig):
        super(LukeModel, self).__init__()

        self.config = config

        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        if self.config.bert_model_name and "roberta" in self.config.bert_model_name:
            self.embeddings = RobertaEmbeddings(config)
            self.embeddings.token_type_embeddings.requires_grad = False
        else:
            self.embeddings = BertEmbeddings(config)
        self.entity_embeddings = EntityEmbeddings(config)
Example #13
    def __init__(self, config, add_pooling_layer=True):
        # Call the init one parent class up. Otherwise, the model will be defined twice.
        BertPreTrainedModel.__init__(self, config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        # Sparsify linear modules.
        self.sparsify_model()

        self.init_weights()
Example #14
def get_modules():
    params = copy.deepcopy(PARAMS_DICT)

    params["hidden_dropout_prob"] = params.pop("dropout")
    params["hidden_size"] = params.pop("embedding_size")

    # bert, roberta, electra self attentions have the same code.

    torch.manual_seed(1234)
    yield "bert", BertEmbeddings(BertConfig(**params))

    albertparams = copy.deepcopy(PARAMS_DICT)
    albertparams["hidden_dropout_prob"] = albertparams.pop("dropout")

    torch.manual_seed(1234)
    yield "albert", AlbertEmbeddings(AlbertConfig(**albertparams))
Example #15
    def __init__(self, config: BertConfig, num_hidden_layers=None):
        super().__init__()
        self.logger = get_logger(__name__)
        config.output_hidden_states = True
        self.embeddings = BertEmbeddings(config)
        num_hidden_layers = config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
        assert num_hidden_layers > 0, 'bert_layers must be > 0'

        # Note: the outputs are not the same as those of the original transformer BERT_Encoder
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)])
        self.config = config
        self.num_hidden_layers = num_hidden_layers
        self.apply(self.init_bert_weights)
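This hand-rolled encoder does not define a forward pass in the snippet; a minimal sketch of how the copied BertLayer stack is typically driven, collecting the per-layer hidden states that output_hidden_states=True promises (hypothetical, masks omitted, assuming the transformers BertLayer interface that returns a tuple):

    def _sketch_forward(self, input_ids, token_type_ids=None):
        hidden_states = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        all_hidden_states = [hidden_states]
        for layer_module in self.layer:
            # Element 0 of the BertLayer output tuple is the updated hidden states.
            hidden_states = layer_module(hidden_states)[0]
            all_hidden_states.append(hidden_states)
        return hidden_states, all_hidden_states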
Example #16
def get_modules(params_dict):
    modules = {}
    params = copy.deepcopy(params_dict)

    params["hidden_dropout_prob"] = params.pop("dropout")
    params["hidden_size"] = params.pop("embedding_size")

    # bert, roberta, electra self attentions have the same code.

    torch.manual_seed(1234)
    hf_module = BertEmbeddings(BertConfig(**params))
    modules["bert"] = hf_module

    albertparams = copy.deepcopy(params_dict)
    albertparams["hidden_dropout_prob"] = albertparams.pop("dropout")

    torch.manual_seed(1234)
    hf_module = AlbertEmbeddings(AlbertConfig(**albertparams))
    modules["albert"] = hf_module

    return modules
Example #17
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.init_weights()
Example #18
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()

        bert_config = BertConfig(
            vocab_size=config["vocab_size"],
            hidden_size=config["hidden_size"],
            num_hidden_layers=config["num_layers"],
            num_attention_heads=config["num_heads"],
            intermediate_size=config["hidden_size"] * config["mlp_ratio"],
            max_position_embeddings=config["max_text_len"],
            hidden_dropout_prob=config["drop_rate"],
            attention_probs_dropout_prob=config["drop_rate"],
        )
        self.tempeture_max_OT = config['tempeture_max_OT']
        self.text_embeddings = BertEmbeddings(bert_config)
        self.text_embeddings.apply(objectives.init_weights)

        self.token_type_embeddings = nn.Embedding(2, config["hidden_size"])
        self.token_type_embeddings.apply(objectives.init_weights)

        import vilt.modules.vision_transformer as vit

        if self.hparams.config["load_path"] == "":
            self.transformer = getattr(vit, self.hparams.config["vit"])(
                pretrained=config["pretrained_flag"], config=self.hparams.config)
        else:
            self.transformer = getattr(vit, self.hparams.config["vit"])(
                pretrained=False, config=self.hparams.config
            )

        self.pooler = heads.Pooler(config["hidden_size"])
        self.pooler.apply(objectives.init_weights)

        if config["loss_names"]["mlm"] > 0:
            self.mlm_score = heads.MLMHead(bert_config)
            self.mlm_score.apply(objectives.init_weights)

        if config["loss_names"]["itm"] > 0:
            self.itm_score = heads.ITMHead(config["hidden_size"])
            self.itm_score.apply(objectives.init_weights)

        if config["loss_names"]["mpp"] > 0:
            self.mpp_score = heads.MPPHead(bert_config)
            self.mpp_score.apply(objectives.init_weights)

        # ===================== Downstream ===================== #
        if (
            self.hparams.config["load_path"] != ""
            and not self.hparams.config["test_only"]
        ):
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
            state_dict = ckpt["state_dict"]
            self.load_state_dict(state_dict, strict=False)
            print(f'Loading checkpoint from {self.hparams.config["load_path"]}')

        hs = self.hparams.config["hidden_size"]

        if self.hparams.config["loss_names"]["vqa"] > 0:
            vs = self.hparams.config["vqav2_label_size"]
            self.vqa_classifier = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.LayerNorm(hs * 2),
                nn.GELU(),
                nn.Linear(hs * 2, vs),
            )
            self.vqa_classifier.apply(objectives.init_weights)

        if self.hparams.config["loss_names"]["nlvr2"] > 0:
            self.nlvr2_classifier = nn.Sequential(
                nn.Linear(hs * 2, hs * 2),
                nn.LayerNorm(hs * 2),
                nn.GELU(),
                nn.Linear(hs * 2, 2),
            )
            self.nlvr2_classifier.apply(objectives.init_weights)
            emb_data = self.token_type_embeddings.weight.data
            self.token_type_embeddings = nn.Embedding(3, hs)
            self.token_type_embeddings.apply(objectives.init_weights)
            self.token_type_embeddings.weight.data[0, :] = emb_data[0, :]
            self.token_type_embeddings.weight.data[1, :] = emb_data[1, :]
            self.token_type_embeddings.weight.data[2, :] = emb_data[1, :]

        if self.hparams.config["loss_names"]["irtr"] > 0:
            self.rank_output = nn.Linear(hs, 1)
            self.rank_output.weight.data = self.itm_score.fc.weight.data[1:, :]
            self.rank_output.bias.data = self.itm_score.fc.bias.data[1:]
            self.margin = 0.2
            for p in self.itm_score.parameters():
                p.requires_grad = False

        vilt_utils.set_metrics(self)
        self.current_tasks = list()

        # ===================== load downstream (test_only) ======================

        if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
            state_dict = ckpt["state_dict"]
            self.load_state_dict(state_dict, strict=False)
            print(f'Loading checkpoint from {self.hparams.config["load_path"]}')
Example #19
class ViLTransformerSS(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()

        bert_config = BertConfig(
            vocab_size=config["vocab_size"],
            hidden_size=config["hidden_size"],
            num_hidden_layers=config["num_layers"],
            num_attention_heads=config["num_heads"],
            intermediate_size=config["hidden_size"] * config["mlp_ratio"],
            max_position_embeddings=config["max_text_len"],
            hidden_dropout_prob=config["drop_rate"],
            attention_probs_dropout_prob=config["drop_rate"],
        )
        self.tempeture_max_OT = config['tempeture_max_OT']
        self.text_embeddings = BertEmbeddings(bert_config)
        self.text_embeddings.apply(objectives.init_weights)

        self.token_type_embeddings = nn.Embedding(2, config["hidden_size"])
        self.token_type_embeddings.apply(objectives.init_weights)

        import vilt.modules.vision_transformer as vit

        if self.hparams.config["load_path"] == "":
            self.transformer = getattr(vit, self.hparams.config["vit"])(
                pretrained=config["pretrained_flag"], config=self.hparams.config)
        else:
            self.transformer = getattr(vit, self.hparams.config["vit"])(
                pretrained=False, config=self.hparams.config
            )

        self.pooler = heads.Pooler(config["hidden_size"])
        self.pooler.apply(objectives.init_weights)

        if config["loss_names"]["mlm"] > 0:
            self.mlm_score = heads.MLMHead(bert_config)
            self.mlm_score.apply(objectives.init_weights)

        if config["loss_names"]["itm"] > 0:
            self.itm_score = heads.ITMHead(config["hidden_size"])
            self.itm_score.apply(objectives.init_weights)

        if config["loss_names"]["mpp"] > 0:
            self.mpp_score = heads.MPPHead(bert_config)
            self.mpp_score.apply(objectives.init_weights)

        # ===================== Downstream ===================== #
        if (
            self.hparams.config["load_path"] != ""
            and not self.hparams.config["test_only"]
        ):
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
            state_dict = ckpt["state_dict"]
            self.load_state_dict(state_dict, strict=False)
            print(f'Loading checkpoint from {self.hparams.config["load_path"]}')

        hs = self.hparams.config["hidden_size"]

        if self.hparams.config["loss_names"]["vqa"] > 0:
            vs = self.hparams.config["vqav2_label_size"]
            self.vqa_classifier = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.LayerNorm(hs * 2),
                nn.GELU(),
                nn.Linear(hs * 2, vs),
            )
            self.vqa_classifier.apply(objectives.init_weights)

        if self.hparams.config["loss_names"]["nlvr2"] > 0:
            self.nlvr2_classifier = nn.Sequential(
                nn.Linear(hs * 2, hs * 2),
                nn.LayerNorm(hs * 2),
                nn.GELU(),
                nn.Linear(hs * 2, 2),
            )
            self.nlvr2_classifier.apply(objectives.init_weights)
            emb_data = self.token_type_embeddings.weight.data
            self.token_type_embeddings = nn.Embedding(3, hs)
            self.token_type_embeddings.apply(objectives.init_weights)
            self.token_type_embeddings.weight.data[0, :] = emb_data[0, :]
            self.token_type_embeddings.weight.data[1, :] = emb_data[1, :]
            self.token_type_embeddings.weight.data[2, :] = emb_data[1, :]

        if self.hparams.config["loss_names"]["irtr"] > 0:
            self.rank_output = nn.Linear(hs, 1)
            self.rank_output.weight.data = self.itm_score.fc.weight.data[1:, :]
            self.rank_output.bias.data = self.itm_score.fc.bias.data[1:]
            self.margin = 0.2
            for p in self.itm_score.parameters():
                p.requires_grad = False

        vilt_utils.set_metrics(self)
        self.current_tasks = list()

        # ===================== load downstream (test_only) ======================

        if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
            state_dict = ckpt["state_dict"]
            self.load_state_dict(state_dict, strict=False)
            print(f'Loading checkpoint from {self.hparams.config["load_path"]}')

    def infer(
        self,
        batch,
        mask_text=False,
        mask_image=False,
        image_token_type_idx=1,
        image_embeds=None,
        image_masks=None,
    ):
        if f"image_{image_token_type_idx - 1}" in batch:
            imgkey = f"image_{image_token_type_idx - 1}"
        else:
            imgkey = "image"

        do_mlm = "_mlm" if mask_text else ""
        text_ids = batch[f"text_ids{do_mlm}"]
        text_labels = batch[f"text_labels{do_mlm}"]
        text_masks = batch[f"text_masks"]
        text_embeds = self.text_embeddings(text_ids)

        if image_embeds is None and image_masks is None:
            img = batch[imgkey][0]
            (
                image_embeds,
                image_masks,
                patch_index,
                image_labels,
            ) = self.transformer.visual_embed(
                img,
                max_patch_len=self.hparams.config["max_patch_len"],
                max_image_len=self.hparams.config["max_image_len"],
                mask_it=mask_image,
            )
        else:
            patch_index, image_labels = (
                None,
                None,
            )

        text_embeds, image_embeds = (
            text_embeds + self.token_type_embeddings(torch.zeros_like(text_masks)),
            image_embeds
            + self.token_type_embeddings(
                torch.full_like(image_masks, image_token_type_idx)
            ),
        )

        co_embeds = torch.cat([text_embeds, image_embeds], dim=1)
        co_masks = torch.cat([text_masks, image_masks], dim=1)

        x = co_embeds

        for i, blk in enumerate(self.transformer.blocks):
            x, _attn = blk(x, mask=co_masks)

        x = self.transformer.norm(x)
        text_feats, image_feats = (
            x[:, : text_embeds.shape[1]],
            x[:, text_embeds.shape[1] :],
        )
        cls_feats = self.pooler(x)

        ret = {
            "text_feats": text_feats,
            "image_feats": image_feats,
            "cls_feats": cls_feats,
            "raw_cls_feats": x[:, 0],
            "image_labels": image_labels,
            "image_masks": image_masks,
            "text_labels": text_labels,
            "text_ids": text_ids,
            "text_masks": text_masks,
            "patch_index": patch_index,
        }

        return ret

    def forward(self, batch):
        ret = dict()
        if len(self.current_tasks) == 0:
            ret.update(self.infer(batch))
            return ret

        # Masked Language Modeling
        if "mlm" in self.current_tasks:
            ret.update(objectives.compute_mlm(self, batch))

        # Masked Patch Prediction
        if "mpp" in self.current_tasks:
            ret.update(objectives.compute_mpp(self, batch))

        # Image Text Matching
        if "itm" in self.current_tasks:
            if self.tempeture_max_OT:
                if self.trainer:
                    temp = self.trainer.global_step / self.trainer.max_steps
                else:
                    temp = 0 # for
                ret.update(objectives.compute_itm_wpa_tmp_max_ot(self, batch, temp))
            else:
                ret.update(objectives.compute_itm_wpa(self, batch))

        # Visual Question Answering
        if "vqa" in self.current_tasks:
            ret.update(objectives.compute_vqa(self, batch))

        # Natural Language for Visual Reasoning 2
        if "nlvr2" in self.current_tasks:
            ret.update(objectives.compute_nlvr2(self, batch))

        # Image Retrieval and Text Retrieval
        if "irtr" in self.current_tasks:
            ret.update(objectives.compute_irtr(self, batch))

        return ret

    def training_step(self, batch, batch_idx):
        vilt_utils.set_task(self)
        output = self(batch)
        total_loss = sum([v for k, v in output.items() if "loss" in k])

        return total_loss

    def training_epoch_end(self, outs):
        vilt_utils.epoch_wrapup(self)

    def validation_step(self, batch, batch_idx):
        vilt_utils.set_task(self)
        output = self(batch)

    def validation_epoch_end(self, outs):
        vilt_utils.epoch_wrapup(self)

    def test_step(self, batch, batch_idx):
        vilt_utils.set_task(self)
        output = self(batch)
        ret = dict()

        if self.hparams.config["loss_names"]["vqa"] > 0:
            ret.update(objectives.vqa_test_step(self, batch, output))

        return ret

    def test_epoch_end(self, outs):
        model_name = self.hparams.config["load_path"].split("/")[-1][:-5]

        if self.hparams.config["loss_names"]["vqa"] > 0:
            objectives.vqa_test_wrapup(outs, model_name)
        vilt_utils.epoch_wrapup(self)

    def configure_optimizers(self):
        return vilt_utils.set_schedule(self)
Example #20
    for line in f.readlines():
        line = eval(line.strip())
        data.append((line["abst"], line["title"]))
        i += 1
        if i == 10:
            break

#%%
max_len = 512



tokenizer = BertTokenizer.from_pretrained("unilm_chinese")
batch_tensors = prepare_batches(data, tokenizer, max_len)


model = UnilmForSeq2Seq.from_pretrained("unilm_chinese")


#%%
loss = model(**batch_tensors)


#%% 
from transformers.models.bert.modeling_bert import BertEmbeddings
from configuration_unilm import UnilmConfig

config = UnilmConfig.from_pretrained("unilm_chinese")

embeddings = BertEmbeddings(config)
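With the UniLM config loaded, the freshly constructed BertEmbeddings layer is randomly initialized (it does not carry the pretrained UniLM weights). A quick, hypothetical smoke test, assuming UnilmConfig exposes the usual BERT config fields:

    enc = tokenizer("这是一个测试", return_tensors="pt")
    emb_out = embeddings(input_ids=enc["input_ids"], token_type_ids=enc["token_type_ids"])
    print(emb_out.shape)  # (1, sequence_length, config.hidden_size)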