Code Example #1
    def __init__(self, config):
        """Initialize the model with config dict.

        Args:
            config: python dict must contains the attributes below:
                config.bert_model_path: pretrained model path or model type
                    e.g. 'bert-base-chinese'
                config.hidden_size: The same as BERT model, usually 768
                config.num_classes: int, e.g. 2
                config.dropout: float between 0 and 1
        """
        super().__init__()
        if 'xl' in config.model_type:
            self.bert = AutoModel.from_pretrained(config.bert_model_path)
        else:
            self.bert = BertModel.from_pretrained(config.bert_model_path)

        # Fine-tune the full encoder: keep all BERT parameters trainable.
        for param in self.bert.parameters():
            param.requires_grad = True
        self.dropout = nn.Dropout(config.dropout)
        self.linear = nn.Linear(4, config.num_classes)
        self.num_classes = config.num_classes

        self.dim_capsule = config.dim_capsule
        self.num_compressed_capsule = config.num_compressed_capsule
        self.ngram_size = [2, 4, 8]
        self.convs_doc = nn.ModuleList([
            nn.Conv1d(config.max_seq_len, 32, K, stride=2)
            for K in self.ngram_size
        ])
        for conv in self.convs_doc:
            torch.nn.init.xavier_uniform_(conv.weight)

        self.primary_capsules_doc = PrimaryCaps(num_capsules=self.dim_capsule,
                                                in_channels=32,
                                                out_channels=32,
                                                kernel_size=1,
                                                stride=1)

        self.flatten_capsules = FlattenCaps()

        # The first dimension is the flattened primary-capsule feature size,
        # which depends on the encoder hidden size (768 vs. 1024).
        if config.hidden_size == 768:
            self.W_doc = nn.Parameter(
                torch.FloatTensor(147328, self.num_compressed_capsule))
        else:  # hidden_size == 1024
            self.W_doc = nn.Parameter(
                torch.FloatTensor(196480, self.num_compressed_capsule))
        torch.nn.init.xavier_uniform_(self.W_doc)

        self.fc_capsules_doc_child = FCCaps(
            config,
            output_capsule_num=config.num_classes,
            input_capsule_num=self.num_compressed_capsule,
            in_channels=self.dim_capsule,
            out_channels=self.dim_capsule)
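
Code Example #1 is the constructor only; the sketch below shows how a matching config could be assembled and passed in. The class name CapsuleBertClassifier and all concrete values are assumptions for illustration, only the attribute names come from the constructor above.

# Minimal usage sketch (assumed names/values): build an attribute-style config
# with every field the constructor reads, then instantiate the wrapping class.
from types import SimpleNamespace

config = SimpleNamespace(
    bert_model_path='bert-base-chinese',  # pretrained model path or name
    model_type='bert',                    # no 'xl' -> loaded with BertModel
    hidden_size=768,
    num_classes=2,
    dropout=0.1,
    max_seq_len=512,             # assumed value
    dim_capsule=16,              # assumed value
    num_compressed_capsule=128,  # assumed value
)

model = CapsuleBertClassifier(config)  # hypothetical name for the class in Code Example #1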
Code Example #2
    def __init__(self, config):
        super(BertSupportNetX, self).__init__()

        self.encoder = BertModel.from_pretrained(config.bert_model_path)
        self.config = config  # i.e. the args object
        self.max_query_length = self.config.max_query_len
        self.input_dim = config.hidden_size
        self.dropout_size = config.dropout
        self.dropout = nn.Dropout(self.dropout_size)

        self.dim_capsule = config.dim_capsule
        self.num_compressed_capsule = config.num_compressed_capsule
        self.ngram_size = [2, 4, 8]
        self.convs_doc = nn.ModuleList([
            nn.Conv1d(config.max_seq_len, 32, K, stride=2)
            for K in self.ngram_size
        ])
        for conv in self.convs_doc:
            torch.nn.init.xavier_uniform_(conv.weight)

        self.primary_capsules_doc = PrimaryCaps(num_capsules=self.dim_capsule,
                                                in_channels=32,
                                                out_channels=32,
                                                kernel_size=1,
                                                stride=1)

        self.flatten_capsules = FlattenCaps()

        # 49024 is the flattened primary-capsule feature size for this
        # configuration of convolutions and capsules.
        self.W_doc = nn.Parameter(
            torch.FloatTensor(49024, self.num_compressed_capsule))
        torch.nn.init.xavier_uniform_(self.W_doc)

        self.fc_capsules_doc_child = FCCaps(
            config,
            output_capsule_num=config.num_classes,
            input_capsule_num=self.num_compressed_capsule,
            in_channels=self.dim_capsule,
            out_channels=self.dim_capsule)

        self.start_linear = nn.Linear(self.input_dim * 2, 1)
        self.end_linear = nn.Linear(self.input_dim * 2, 1)
        self.type_linear = nn.Linear(self.input_dim,
                                     config.num_classes)  # yes/no/ans/unknown
        self.sp_linear = nn.Linear(self.input_dim, 1)

        self.cache_S = 0        # cached sequence length
        self.cache_mask = None  # cached mask for that sequence length
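
As with the first example, a hedged instantiation sketch for BertSupportNetX follows; the concrete values are assumptions, only the attribute names come from the constructor above.

# Minimal usage sketch (assumed values): BertSupportNetX reads these config
# attributes in its constructor.
from types import SimpleNamespace

config = SimpleNamespace(
    bert_model_path='bert-base-chinese',
    hidden_size=768,
    dropout=0.1,
    max_query_len=64,            # assumed value
    max_seq_len=512,             # assumed value
    dim_capsule=16,              # assumed value
    num_compressed_capsule=128,  # assumed value
    num_classes=4,               # type head: yes / no / span answer / unknown
)

model = BertSupportNetX(config)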