Example #1
    def __init__(self, config: Dict):
        super().__init__()

        self.config = config
        self.model_config = DistilBertConfig(**self.config["model"])
        self.model = DistilBertModel(self.model_config)
        self.criterion = nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')
Example #2
def main():
    # define parser and arguments
    args = get_train_test_args()
    util.set_seed(args.seed)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    DistilBert = DistilBertModel.from_pretrained('distilbert-base-uncased')
    Experts = [DistilBertQA(DistilBertModel.from_pretrained('distilbert-base-uncased')).to(device) for _ in range(args.num_experts)]
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    gate_model = GateNetwork(384, 3, 3, DistilBert.config).to(device)
    print(f'Args: {json.dumps(vars(args), indent=4, sort_keys=True)}')
    if args.do_train:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        args.save_dir = util.get_save_dir(args.save_dir, args.run_name)
        log = util.get_logger(args.save_dir, 'log_train')
        log.info(f'Args: {json.dumps(vars(args), indent=4, sort_keys=True)}')
        log.info("Preparing Training Data...")
        args.device = device
        trainer = train.Trainer(args, log)
        train_dataset, _ = get_dataset(args, args.train_datasets, args.train_dir, tokenizer, 'train')
        log.info("Preparing Validation Data...")
        val_dataset, val_dict = get_dataset(args, args.train_datasets, args.val_dir, tokenizer, 'val')
        train_loader = DataLoader(train_dataset,
                                batch_size=args.batch_size,
                                sampler=RandomSampler(train_dataset))
        val_loader = DataLoader(val_dataset,
                                batch_size=1,
                                sampler=SequentialSampler(val_dataset))
        best_scores = trainer.train(Experts, gate_model, train_loader, val_loader, val_dict, args.num_experts)
    if args.do_eval:
        split_name = 'test' if 'test' in args.eval_dir else 'validation'
        log = util.get_logger(args.save_dir, f'log_{split_name}')
        trainer = train.Trainer(args, log)
        # load model
        restore_model("",args.num_experts, Experts, gate_model)
        eval_dataset, eval_dict = get_dataset(args, args.eval_datasets, args.eval_dir, tokenizer, split_name)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=1,
                                 sampler=SequentialSampler(eval_dataset))
        args.device = device
        eval_preds, eval_scores = trainer.evaluate(Experts, gate_model, eval_loader,
                                                   eval_dict, return_preds=True,
                                                   split=split_name)
        results_str = ', '.join(f'{k}: {v:05.2f}' for k, v in eval_scores.items())
        log.info(f'Eval {results_str}')
        # Write submission file
        sub_path = os.path.join(args.save_dir, split_name + '_' + args.sub_file)
        log.info(f'Writing submission file to {sub_path}...')
        with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
            csv_writer = csv.writer(csv_fh, delimiter=',')
            csv_writer.writerow(['Id', 'Predicted'])
            for uuid in sorted(eval_preds):
                csv_writer.writerow([uuid, eval_preds[uuid]])
Example #3
        def create_and_check_distilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = DistilBertModel(config=config)
            model.eval()
            (sequence_output,) = model(input_ids, input_mask)
            (sequence_output,) = model(input_ids)

            result = {
                "sequence_output": sequence_output,
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
Example #4
File: nl2sql.py Project: vzhong/gazp
    def __init__(self, args, ext):
        super().__init__(args, ext)
        self.conv = converter.Converter(tables=getattr(args, 'tables',
                                                       'data/spider/tables'),
                                        db=getattr(args, 'db',
                                                   'data/database'))
        self.bert_tokenizer = DistilBertTokenizer.from_pretrained(
            args.dcache + '/vocab.txt', cache_dir=args.dcache)
        self.bert_embedder = DistilBertModel.from_pretrained(
            args.dcache, cache_dir=args.dcache)
        self.value_bert_embedder = DistilBertModel.from_pretrained(
            args.dcache, cache_dir=args.dcache)
        self.denc = 768
        self.demb = args.demb
        self.sql_vocab = ext['sql_voc']
        self.sql_emb = nn.Embedding.from_pretrained(ext['sql_emb'],
                                                    freeze=False)
        self.pad_id = self.sql_vocab.word2index('PAD')

        self.dropout = nn.Dropout(args.dropout)
        self.bert_dropout = nn.Dropout(args.bert_dropout)
        self.table_sa_scorer = nn.Linear(self.denc, 1)
        self.col_sa_scorer = nn.Linear(self.denc, 1)
        self.col_trans = nn.LSTM(self.denc,
                                 self.demb // 2,
                                 bidirectional=True,
                                 batch_first=True)
        self.table_trans = nn.LSTM(self.denc,
                                   args.drnn,
                                   bidirectional=True,
                                   batch_first=True)
        self.pointer_decoder = decoder.PointerDecoder(
            demb=self.demb,
            denc=2 * args.drnn,
            ddec=args.drnn,
            dropout=args.dec_dropout,
            num_layers=args.num_layers)

        self.utt_trans = nn.LSTM(self.denc,
                                 self.demb // 2,
                                 bidirectional=True,
                                 batch_first=True)
        self.value_decoder = decoder.PointerDecoder(demb=self.demb,
                                                    denc=self.denc,
                                                    ddec=args.drnn,
                                                    dropout=args.dec_dropout,
                                                    num_layers=args.num_layers)

        self.evaluator = evaluation.Evaluator()
        if 'reranker' in ext:
            self.reranker = ext['reranker']
        else:
            self.reranker = rank_max.Module(args, ext, remove_invalid=True)
Example #5
    def __init__(self, pretrained=True, **kwargs):
        super().__init__()
        hidden_dimension = 32

        if pretrained:
            self.bert = DistilBertModel.from_pretrained(
                "distilbert-base-uncased")
        else:
            self.bert = DistilBertModel(DistilBertConfig())
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            "distilbert-base-uncased")
        self.pre_classifier = nn.Linear(self.bert.config.dim, hidden_dimension)
        self.classifier = nn.Linear(hidden_dimension, 1)
Example #6
    def __init__(self, config):
        super(DistilBertForMultiLabelSequenceClassification,
              self).__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.pos_weight = torch.Tensor(
            config.pos_weight).to(device) if config.use_pos_weight else None

        self.init_weights()
Example #7
    def __init__(self, hidden_size, num_labels, drop_prob, freeze, use_img,
                 img_size):
        super(DistilBERT, self).__init__()
        self.img_size = img_size
        self.use_img = use_img
        config = DistilBertConfig(vocab_size=119547)
        self.distilbert = DistilBertModel(config)
        for param in self.distilbert.parameters():
            param.requires_grad = not freeze
        self.classifier = layers.DistilBERTClassifier(hidden_size,
                                                      num_labels,
                                                      drop_prob=drop_prob,
                                                      use_img=use_img,
                                                      img_size=img_size)
Example #8
    def __init__(self,
                 model_name=CFG.text_encoder_model,
                 pretrained=CFG.pretrained,
                 trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0
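
    # The snippet above stops at __init__. A minimal forward() consistent with the
    # CLS-token comment (a sketch, not from the original project) might look like:
    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output[0]  # (batch, seq_len, hidden); works for tuple or ModelOutput returns
        return last_hidden_state[:, self.target_token_idx]  # CLS token embedding per sentence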
Example #9
    def __init__(self, n_outputs, size, pretrained_model_path=False):
        super(DistilBert, self).__init__()
        self.n_outputs = n_outputs
        self.size = size
        self.pretrained_model_path = pretrained_model_path

        if self.pretrained_model_path is False:
            self.huggingface_model = DistilBertModel.from_pretrained(
                f"distilbert-{size}-uncased")
        else:
            self.huggingface_model = DistilBertModel.from_pretrained(
                pretrained_model_path)
        self.dropout = nn.Dropout(0.1)  # hard coding
        self.out_proj = nn.Linear(self.huggingface_model.config.hidden_size,
                                  n_outputs)
Example #10
    def __init__(self) -> None:
        from transformers import DistilBertTokenizer as BertTokenizer
        from transformers import DistilBertModel as BertModel

        pretrained_weights = 'distilbert-base-uncased'
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
        self.model = BertModel.from_pretrained(pretrained_weights)
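
    # Usage sketch (assumed; not part of the original snippet): tokenize a sentence
    # and run it through the model to obtain per-token hidden states.
    def embed(self, text):
        inputs = self.tokenizer(text, return_tensors='pt')
        outputs = self.model(**inputs)
        return outputs[0]  # last hidden states, shape (1, seq_len, 768)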
Example #11
def init_model():
    global tokenizer, bert_model
    MODEL_NAME = "distilbert-base-multilingual-cased"
    print("Loading Bert...")
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
    bert_model = DistilBertModel.from_pretrained(MODEL_NAME)
    print("Done")
Example #12
    def load_model(self, state_path):
        """
        Initialises the model and loads saved state into the instance of the model.

        Parameters
        ----------
        state_path (str) - path pointing to the saved state.

        Returns
        -------
        Model (torch.nn.Module)
        """

        logging.info(f"Loading trained state from {state_path}")
        dbm = DistilBertModel.from_pretrained('distilbert-base-uncased',
                                              return_dict=True)
        device = torch.device(self.device)
        dbm.to(device)
        model = QAModel(transformer_model=dbm, device=device)

        # checkpoint = torch.load(state_path, map_location=device)
        model.load_state_dict(torch.load(state_path))
        model.eval()  # Switch to evaluation mode

        return model
Example #13
    def __init__(self, config):
        super(DistilBertForQuestionAnswering, self).__init__(config)
        self.bert = DistilBertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)  # start/end
        self.dropout = nn.Dropout(0.3)  # RekhaDist
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
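
    # A sketch of the matching forward pass (assumed, not from the original source):
    # project each token's hidden state to start/end logits and split them.
    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.bert(input_ids, attention_mask=attention_mask)[0]  # (batch, seq_len, hidden)
        logits = self.qa_outputs(self.dropout(hidden_states))                   # (batch, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)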
Example #14
    def __init__(self, distilbert_config, args):
        super(DistilBertClassifier, self).__init__(distilbert_config)
        self.args = args
        self.distilbert = DistilBertModel.from_pretrained(args.model_name_or_path, config=distilbert_config)  # Load pretrained distilbert

        self.num_labels = distilbert_config.num_labels
        self.slot_classifier = FCLayer(distilbert_config.hidden_size, distilbert_config.num_labels, args.dropout_rate, use_activation=False)
Example #15
def classify(text):
    print('start')

    path = settings.MEDIA_ROOT + "\\distilbert.bin"  # escaped backslash (Windows path separator)
    MODEL_PATH = 'distilbert-base-uncased'
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    encode = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=192,
        pad_to_max_length=True,
        truncation=True,
    )
    device = torch.device('cpu')
    tokens = encode['input_ids']
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
    tokens = tokens.to(device)
    config = DistilBertConfig()
    model = Bert(DistilBertModel(config))

    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)

    output = model(tokens)
    output = output.cpu().detach().numpy()

    print(output)
    output = 0.0 if output < 0.5 else 1.0
    return output
Example #16
    def __init__(self, hidden_dim, num_classes=2):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased',
                                                    output_attentions=False,
                                                    output_hidden_states=False)
        self.linear = nn.Linear(768, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)
Example #17
File: BERT.py Project: wiragotama/BEA2021
    def __init__(self, vocab: Vocabulary, n_labels: int,
                 torch_device: torch.device) -> None:
        """
        Args:
            vocab (Vocabulary)
            fc_u (int): the number of units of hidden layer for classification
            dropout_rate (float)
            n_labels (int): the number of labels
            torch_device (torch.device): device to use
        """
        super().__init__(vocab)
        self.emb_dim = 768
        self.distil_bert = DistilBertModel.from_pretrained(
            'distilbert-base-multilingual-cased')
        self.fc = nn.Linear(self.emb_dim, n_labels)

        # weight initialization, http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
        # xavier is good for FNN
        torch.nn.init.xavier_uniform_(self.fc.weight)

        # loss
        self.loss = nn.NLLLoss()

        # for saving model
        self.param = {
            "class": "DistilBERTFinetuning",
            "emb_dim": self.emb_dim,
            "n_labels": n_labels
        }
        self.vocab = vocab
        self.running_device = torch_device
Example #18
def build_model_pretrained(config):

    # Create separate tokenizers for the source and target languages.
    src_tokenizer = DistilBertTokenizer.from_pretrained(
        'distilbert-base-multilingual-cased')
    tgt_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    tgt_tokenizer.bos_token = '<s>'
    tgt_tokenizer.eos_token = '</s>'

    #encoder_config = DistilBertConfig.from_pretrained('distilbert-base-multilingual-cased')

    encoder = DistilBertModel.from_pretrained(
        'distilbert-base-multilingual-cased')

    if config.decoder.pretrained:
        decoder = BertForMaskedLM.from_pretrained('bert-base-uncased')
    else:

        decoder_config = BertConfig(vocab_size=tgt_tokenizer.vocab_size,
                                    is_decoder=True)
        decoder = BertForMaskedLM(decoder_config)

    model = TranslationModel(encoder, decoder)
    model.cuda()

    tokenizers = ED({'src': src_tokenizer, 'tgt': tgt_tokenizer})
    return model, tokenizers
Example #19
class DistilBERT(nn.Module):
    """DistilBERT model to classify news

    Based on the paper:
    DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
    by Victor Sanh, Lysandre Debut, Julien Chaumond, Thomas Wolf
    (https://arxiv.org/abs/1910.01108)
    """
    def __init__(self, hidden_size, num_labels, drop_prob, freeze, use_img,
                 img_size):
        super(DistilBERT, self).__init__()
        self.img_size = img_size
        self.use_img = use_img
        config = DistilBertConfig(vocab_size=119547)
        self.distilbert = DistilBertModel(config)
        for param in self.distilbert.parameters():
            param.requires_grad = not freeze
        self.classifier = layers.DistilBERTClassifier(hidden_size,
                                                      num_labels,
                                                      drop_prob=drop_prob,
                                                      use_img=use_img,
                                                      img_size=img_size)

    def forward(self, input_idxs, atten_masks):
        con_x = self.distilbert(input_ids=input_idxs,
                                attention_mask=atten_masks)[0][:, 0]
        # img_x = self.resnet18(images).view(-1, self.img_size) if self.use_img else None
        logit = self.classifier(con_x)
        log = torch.sigmoid(logit)

        return log
Example #20
    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.score_fc = nn.Linear(768, 11)
        self.regression_fc = nn.Linear(768, 1)
        self.sigmoid = nn.Sigmoid()
Example #21
    def __init__(self,
                 model_name_or_path: str,
                 max_seq_length: int = 128,
                 do_lower_case: bool = True):
        super(BERT, self).__init__()
        self.config_keys = ['max_seq_length', 'do_lower_case']
        self.do_lower_case = do_lower_case
        self.fc = nn.Linear(1, 1)

        if max_seq_length > 510:
            logging.warning(
                "BERT only allows a max_seq_length of 510 (512 with special tokens). Value will be set to 510"
            )
            max_seq_length = 510
        self.max_seq_length = max_seq_length
        self.bert = DistilBertModel.from_pretrained(model_name_or_path)
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            '/data/premnadh/Hybrid-QASystem/sentence_transformers/Vocab/DistilBert_Vocab.txt',
            do_lower_case=do_lower_case)
        # if(model_name_or_path is not None):
        #    self.tokenizer.save_vocabulary('/data/premnadh/Hybrid-QASystem/sentence_transformers/Vocab/DistilBert_Vocab.txt')
        self.cls_token_id = self.tokenizer.convert_tokens_to_ids(
            [self.tokenizer.cls_token])[0]
        self.sep_token_id = self.tokenizer.convert_tokens_to_ids(
            [self.tokenizer.sep_token])[0]
Example #22
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)

        '''
          Adding a fully-convolutional classifier inspired by DeepLab V3+, used for
          semantic segmentation.
          Using an Atrous Spatial Pyramid Pooling (ASPP) module with dilated convolutions at
          different dilation rates, in order to capture larger, coarser structures,
          possibly capturing the structure of an answer better.
          The ASPP module halves the size of the input sequence, so we upsample x2 after
          the scoring layer.
        '''
        # Spatial Pyramid Pooling with dilated convs
        self.qa_aspp_r3 = Conv1d(config.dim, config.dim, 3, stride=2, dilation=6, groups=config.dim, padding=6)
        self.qa_aspp_r3_1x1 = Conv1d(config.dim, config.dim // 4, 1)

        self.qa_aspp_r6 = Conv1d(config.dim, config.dim, 3, stride=2, dilation=12, groups=config.dim, padding=12)
        self.qa_aspp_r6_1x1 = Conv1d(config.dim, config.dim // 4, 1)

        self.qa_aspp_r12 = Conv1d(config.dim, config.dim, 3, stride=2, dilation=18, groups=config.dim, padding=18)
        self.qa_aspp_r12_1x1 = Conv1d(config.dim, config.dim // 4, 1)

        self.qa_aspp_score = Conv1d(config.dim // 4 * 3, config.num_labels, 1)
        # self.LayerNorm_aspp = nn.LayerNorm(normalized_shape = [384,2])
        self.upsampling2D = nn.Upsample(scale_factor=2, mode='bilinear')

        assert config.num_labels == 2
        # self.dropout = nn.Dropout(config.qa_dropout)

        # self.LayerNorm = nn.LayerNorm(normalized_shape = [384,2])

        self.init_weights()
Example #23
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vocab_size = config.vocab_size

        self.distilbert = DistilBertModel(config)
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.Q_cls = nn.ModuleDict()

        for T in range(2):
            # ModuleDict keys have to be strings..
            self.Q_cls['%d' % T] = nn.Sequential(
                nn.Linear(config.hidden_size + self.num_labels, 200),
                nn.ReLU(),
                nn.Linear(200, self.num_labels))

        self.g_cls = nn.Linear(config.hidden_size + self.num_labels, 
            self.config.num_labels)

        self.init_weights()
Example #24
def get_distilkobert_model(no_cuda=False):
    model = DistilBertModel.from_pretrained('monologg/distilkobert')

    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
    model.to(device)

    return model
Example #25
    def __init__(self, config):
        self.mode = "train"
        self.config = config
        print(self.config)
        self.load_config()

        # bert tokenizer and model
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
        self.word2id = self.tokenizer.get_vocab()
        self.word_vocab = {value:key for key, value in self.word2id.items()}
        bert_model = DistilBertModel.from_pretrained('distilbert-base-cased')
        bert_model.transformer = None
        bert_model.encoder = None
        for param in bert_model.parameters():
            param.requires_grad = False

        self.online_net = Policy(config=self.config, bert_model=bert_model, word_vocab_size=len(self.word2id))
        self.target_net = Policy(config=self.config, bert_model=bert_model, word_vocab_size=len(self.word2id))
        self.online_net.train()
        self.target_net.train()
        self.update_target_net()
        for param in self.target_net.parameters():
            param.requires_grad = False
        if self.use_cuda:
            self.online_net.cuda()
            self.target_net.cuda()

        # optimizer
        self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=self.config['general']['training']['optimizer']['learning_rate'])
        self.clip_grad_norm = self.config['general']['training']['optimizer']['clip_grad_norm']

        # losses
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
Example #26
    def __init__(self, vocab_size):
        super(DistilBertEncoder, self).__init__()
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.bert.resize_token_embeddings(vocab_size)

        for param in self.bert.parameters():
            param.requires_grad = False
Example #27
def main():
    with dask.config.set(scheduler='synchronous'):
        data_dir = PosixPath("~/recsys2020").expanduser()
        ds_name = "user_sampled"
        input_file = data_dir / f"{ds_name}.parquet/"
        output_file = data_dir / f"{ds_name}_embeddings.parquet/"
        df = dd.read_parquet(str(input_file))

        meta = {
            'user_id': str,
            'tweet_id': str,
            'tokens': object,
            'embeddings': object
        }
        d = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        with torch.no_grad():
            if arch == 'distilbert':
                model = DistilBertModel.from_pretrained(
                    'distilbert-base-multilingual-cased',
                    output_hidden_states=True)
            elif arch == 'bert':
                model = BertModel.from_pretrained(
                    'bert-base-multilingual-cased', output_hidden_states=True)
            model = model.eval().to(d)
            df = df[['user_id', 'tweet_id',
                     'tokens']].map_partitions(embed_partition,
                                               d=d,
                                               model=model,
                                               meta=meta)
            del df['tokens']
            df.to_parquet(output_file)
Example #28
    def __init__(self,
                 text_dim=1268 + 4,
                 hidden_dim=200,
                 img_dim=1000,
                 rep_dim=500,
                 output_dim=4):
        super(Basic, self).__init__()

        self.hidden_layer = nn.Linear(text_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased",
                                                    return_dict=True)

        self.image_model = torch.hub.load('pytorch/vision:v0.6.0',
                                          'resnet18',
                                          pretrained=True)

        self.main = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

        self.image_main = nn.Sequential(nn.Linear(img_dim, rep_dim), )
Example #29
File: DistillBERT.py Project: MarkoKat/NLP
    def __init__(self):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        # self.classifier = torch.nn.Linear(768, 16)
        self.classifier = torch.nn.Linear(768, 9)
Example #30
    def __init__(self, visual_encoder, N_data,
                 emb_dim=256, dropout=0, K=4096, T=0.07, m=0.5, gpu=None):
        super(CPD, self).__init__()

        self.visual_encoder = visual_encoder
        self.textual_encoder = DistilBertModel.from_pretrained(
            'distilbert-base-uncased')

        self.emb_dim = emb_dim
        self.dropout = dropout
        self._prepare_base_model()

        self.vis_emb = nn.Linear(self.feature_dim, emb_dim)
        self.text_emb = nn.Sequential(nn.Linear(
            768, emb_dim*2), nn.BatchNorm1d(emb_dim*2), nn.ReLU(), nn.Linear(emb_dim*2, emb_dim))

        self.N_data = N_data
        self.K = K
        self.T = T
        self.m = m
        self.unigrams = torch.ones(N_data)
        self.multinomial = AliasMethod(self.unigrams)
        self.multinomial.cuda(gpu)
        stdv = 1. / math.sqrt(emb_dim / 3)
        self.register_buffer('Z_v', torch.tensor([-1.0]))
        self.register_buffer('Z_t', torch.tensor([-1.0]))
        self.register_buffer('vis_memory', torch.rand(
            N_data, emb_dim).mul_(2 * stdv).add_(-stdv))
        self.register_buffer('text_memory', torch.rand(
            N_data, emb_dim).mul_(2 * stdv).add_(-stdv))