Example #1
    def build_model(self):
        """
        Construct the model.
        """
        num_classes = len(self.class_list)
        return BertWrapper(BertModel.from_pretrained(self.pretrained_path),
                           num_classes)
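
Every example in this listing wraps a pretrained BertModel in a BertWrapper. For orientation, here is a minimal sketch of the constructor interface these calls appear to assume; the parameter names are taken from the keyword arguments used in Examples #2-#4, but the defaults and the body are assumptions, not the actual ParlAI implementation.

import torch

class BertWrapper(torch.nn.Module):
    # Sketch only: wraps a BERT encoder and projects its pooled
    # output to `output_dim` (a class count or an embedding size).
    def __init__(self, bert_model, output_dim,
                 add_transformer_layer=False, layer_pulled=-1,
                 aggregation='first'):
        super().__init__()
        self.bert_model = bert_model
        self.add_transformer_layer = add_transformer_layer
        self.layer_pulled = layer_pulled
        self.aggregation = aggregation
        hidden_dim = bert_model.config.hidden_size
        self.additional_linear_layer = torch.nn.Linear(hidden_dim, output_dim)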
Example #2
    def __init__(self, opt, dictionary):
        from parlai.agents.bert_ranker.helpers import BertWrapper

        try:
            from pytorch_pretrained_bert import BertModel
        except ImportError:
            raise Exception(
                "BERT rankers need pytorch-pretrained-BERT installed. "
                "\npip install pytorch-pretrained-bert")
        super().__init__()
        self.opt = opt
        self.pad_idx = dictionary[PAD_TOKEN]
        self.ctx_bert = BertWrapper(
            bert_model=BertModel.from_pretrained(BERT_ID),
            output_dim=opt.bert_dim,
            add_transformer_layer=opt.bert_add_transformer_layer,
        )
        self.cand_bert = BertWrapper(
            bert_model=BertModel.from_pretrained(BERT_ID),
            output_dim=opt.bert_dim,
            add_transformer_layer=opt.bert_add_transformer_layer,
        )

        # Reset the embeddings for the until-now unused BERT tokens
        orig_embedding_weights = BertModel.from_pretrained(
            BERT_ID).embeddings.word_embeddings.weight
        mean_val = orig_embedding_weights.mean().item()
        std_val = orig_embedding_weights.std().item()
        unused_tokens = [
            START_OF_COMMENT, PARLAI_PAD_TOKEN, EMPTYPERSONA_TOKEN
        ]
        unused_token_idxes = [dictionary[token] for token in unused_tokens]
        for token_idx in unused_token_idxes:
            rand_embedding = orig_embedding_weights.new_empty(
                (1, orig_embedding_weights.size(1))).normal_(mean=mean_val,
                                                             std=std_val)
            for embeddings in [
                    self.ctx_bert.bert_model.embeddings.word_embeddings,
                    self.cand_bert.bert_model.embeddings.word_embeddings,
            ]:
                # write via .data so the in-place assignment bypasses
                # autograd's leaf-variable check
                embeddings.weight.data[token_idx] = rand_embedding
        self.ctx_bert.bert_model.embeddings.word_embeddings.weight.detach_()
        self.cand_bert.bert_model.embeddings.word_embeddings.weight.detach_()
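
The embedding reset above is a reusable trick: rows for vocabulary slots that BERT never trained on are redrawn from a normal distribution matched to the mean and standard deviation of the pretrained embedding matrix, so the repurposed tokens start at the same scale as the rest. A minimal standalone sketch, with placeholder sizes and token ids:

import torch

emb = torch.nn.Embedding(30522, 768)  # stand-in for BERT's word embeddings
unused_ids = [1, 2, 3]                # placeholder ids for repurposed tokens
with torch.no_grad():
    mean = emb.weight.mean().item()
    std = emb.weight.std().item()
    for idx in unused_ids:
        emb.weight[idx].normal_(mean=mean, std=std)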
Example #3
    def __init__(self, opt, dictionary):
        self.pad_idx = dictionary.pad_idx
        self.start_idx = dictionary.start_idx
        self.end_idx = dictionary.end_idx
        self.dictionary = dictionary
        super().__init__(opt, dictionary)
        self.encoder = BertWrapper(
            BertModel.from_pretrained(opt['pretrained_path']),
            opt['embedding_size'],
            add_transformer_layer=opt['add_transformer_layer'],
            layer_pulled=opt['pull_from_layer'],
            aggregation=opt['bert_aggregation'])

    def reorder_encoder_states(self, encoder_states, indices):
        # no support for beam search at this time
        return None
Example #4
    def __init__(self, opt, dictionary):
        self.pad_idx = dictionary.pad_idx
        self.start_idx = dictionary.start_idx
        self.end_idx = dictionary.end_idx
        self.dictionary = dictionary
        self.embeddings = None
        super().__init__(self.pad_idx, self.start_idx, self.end_idx)
        if opt.get('n_positions'):
            # if the number of positions is explicitly provided, use that
            n_positions = opt['n_positions']
        else:
            # else, use the worst case from truncate
            n_positions = max(
                opt.get('truncate') or 0,
                opt.get('text_truncate') or 0,
                opt.get('label_truncate') or 0)
            if n_positions == 0:
                # default to 1024
                n_positions = 1024
        n_segments = opt.get('n_segments', 0)

        if n_positions < 0:
            raise ValueError('n_positions must be positive')

        self.encoder = BertWrapper(
            BertModel.from_pretrained(opt['pretrained_path']),
            opt['out_dim'],
            add_transformer_layer=opt['add_transformer_layer'],
            layer_pulled=opt['pull_from_layer'],
            aggregation=opt['bert_aggregation'])

        self.decoder = _build_decoder(
            opt,
            self.dictionary,
            self.embeddings,
            self.pad_idx,
            n_positions=n_positions,
        )
Example #5
    def build_model(self):
        num_classes = len(self.class_list)
        self.model = BertWrapper(
            BertModel.from_pretrained(self.pretrained_path), num_classes)
Example #6
    def build_model(self):
        num_classes = len(self.class_list)
        self.model = BertWrapper(
            BertModel.from_pretrained(self.opt['bert_id']), num_classes)
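
Taken together, the examples show two configurations: a classifier head (Examples #1, #5, #6, where the second argument is the number of classes) and a ranker or generator encoder (Examples #2-#4, where the output dimension is an embedding size and add_transformer_layer, layer_pulled, and aggregation control how BERT's hidden states are pooled). A hypothetical classifier-style call, assuming the interface sketched under Example #1:

from pytorch_pretrained_bert import BertModel  # as in Example #2

model = BertWrapper(
    BertModel.from_pretrained('bert-base-uncased'),  # placeholder model id
    output_dim=3)  # e.g. a 3-way classifier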