def build_model(self): """ Construct the model. """ num_classes = len(self.class_list) return BertWrapper(BertModel.from_pretrained(self.pretrained_path), num_classes)
def __init__(self, opt, dictionary):
    from parlai.agents.bert_ranker.helpers import BertWrapper
    try:
        from pytorch_pretrained_bert import BertModel
    except ImportError:
        raise Exception(
            "BERT rankers need pytorch-pretrained-BERT installed.\n"
            "pip install pytorch-pretrained-bert"
        )
    super().__init__()
    self.opt = opt
    self.pad_idx = dictionary[PAD_TOKEN]
    self.ctx_bert = BertWrapper(
        bert_model=BertModel.from_pretrained(BERT_ID),
        output_dim=opt.bert_dim,
        add_transformer_layer=opt.bert_add_transformer_layer,
    )
    self.cand_bert = BertWrapper(
        bert_model=BertModel.from_pretrained(BERT_ID),
        output_dim=opt.bert_dim,
        add_transformer_layer=opt.bert_add_transformer_layer,
    )
    # Reset the embeddings for the until-now unused BERT tokens
    orig_embedding_weights = BertModel.from_pretrained(
        BERT_ID
    ).embeddings.word_embeddings.weight
    mean_val = orig_embedding_weights.mean().item()
    std_val = orig_embedding_weights.std().item()
    unused_tokens = [START_OF_COMMENT, PARLAI_PAD_TOKEN, EMPTYPERSONA_TOKEN]
    unused_token_idxes = [dictionary[token] for token in unused_tokens]
    for token_idx in unused_token_idxes:
        rand_embedding = orig_embedding_weights.new_empty(
            (1, orig_embedding_weights.size(1))
        ).normal_(mean=mean_val, std=std_val)
        for embeddings in [
            self.ctx_bert.bert_model.embeddings.word_embeddings,
            self.cand_bert.bert_model.embeddings.word_embeddings,
        ]:
            embeddings.weight[token_idx] = rand_embedding
    # Detach the embedding weights from the graph created by the in-place
    # row assignments above, so they are plain leaf parameters again.
    self.ctx_bert.bert_model.embeddings.word_embeddings.weight.detach_()
    self.cand_bert.bert_model.embeddings.word_embeddings.weight.detach_()
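# A minimal, self-contained sketch (not part of the original file) of the
# re-initialization trick above: rows of an embedding matrix that correspond
# to "unused" vocabulary items are overwritten with random vectors drawn from
# the mean/std of the pretrained embeddings. The names toy_embeddings and
# unused_idxs are illustrative, and torch.no_grad() is used here as a
# simplification instead of the detach_() calls in the original code.
import torch

toy_embeddings = torch.nn.Embedding(10, 4)  # stands in for BERT's word embeddings
unused_idxs = [7, 8, 9]                      # stands in for the unused-token indices

with torch.no_grad():
    mean_val = toy_embeddings.weight.mean().item()
    std_val = toy_embeddings.weight.std().item()
    for idx in unused_idxs:
        toy_embeddings.weight[idx] = toy_embeddings.weight.new_empty(
            (toy_embeddings.weight.size(1),)
        ).normal_(mean=mean_val, std=std_val)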
def __init__(self, opt, dictionary):
    self.pad_idx = dictionary.pad_idx
    self.start_idx = dictionary.start_idx
    self.end_idx = dictionary.end_idx
    self.dictionary = dictionary
    super().__init__(opt, dictionary)
    self.encoder = BertWrapper(
        BertModel.from_pretrained(opt['pretrained_path']),
        opt['embedding_size'],
        add_transformer_layer=opt['add_transformer_layer'],
        layer_pulled=opt['pull_from_layer'],
        aggregation=opt['bert_aggregation'],
    )

def reorder_encoder_states(self, encoder_states, indices):
    # no support for beam search at this time
    return None
def __init__(self, opt, dictionary):
    self.pad_idx = dictionary.pad_idx
    self.start_idx = dictionary.start_idx
    self.end_idx = dictionary.end_idx
    self.dictionary = dictionary
    self.embeddings = None
    super().__init__(self.pad_idx, self.start_idx, self.end_idx)

    if opt.get('n_positions'):
        # if the number of positions is explicitly provided, use that
        n_positions = opt['n_positions']
    else:
        # else, use the worst case from truncate
        n_positions = max(
            opt.get('truncate') or 0,
            opt.get('text_truncate') or 0,
            opt.get('label_truncate') or 0,
        )
        if n_positions == 0:
            # default to 1024
            n_positions = 1024
    n_segments = opt.get('n_segments', 0)
    if n_positions < 0:
        raise ValueError('n_positions must be positive')

    self.encoder = BertWrapper(
        BertModel.from_pretrained(opt['pretrained_path']),
        opt['out_dim'],
        add_transformer_layer=opt['add_transformer_layer'],
        layer_pulled=opt['pull_from_layer'],
        aggregation=opt['bert_aggregation'],
    )
    self.decoder = _build_decoder(
        opt,
        self.dictionary,
        self.embeddings,
        self.pad_idx,
        n_positions=n_positions,
    )
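# A quick illustration (not part of the original file) of the n_positions
# fallback above: an explicit 'n_positions' option wins, otherwise the largest
# of the truncate options is used, and 1024 is the final default. The helper
# name resolve_n_positions is hypothetical.
def resolve_n_positions(opt):
    if opt.get('n_positions'):
        return opt['n_positions']
    n_positions = max(
        opt.get('truncate') or 0,
        opt.get('text_truncate') or 0,
        opt.get('label_truncate') or 0,
    )
    return n_positions if n_positions != 0 else 1024

assert resolve_n_positions({'n_positions': 512}) == 512
assert resolve_n_positions({'text_truncate': 300, 'label_truncate': 128}) == 300
assert resolve_n_positions({}) == 1024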
def build_model(self):
    num_classes = len(self.class_list)
    self.model = BertWrapper(
        BertModel.from_pretrained(self.pretrained_path), num_classes
    )
def build_model(self):
    num_classes = len(self.class_list)
    self.model = BertWrapper(
        BertModel.from_pretrained(self.opt['bert_id']), num_classes
    )