Example #1
# Assumed imports for this snippet; parse_nk, Poly_Phoneme_Classifier,
# Gumbel_Softmax, and Mask_Softmax are project-local modules whose imports
# are not shown in the original example.
import torch.nn as nn
from transformers import BertModel  # assumption: the source may instead use pytorch_pretrained_bert


class G2PTransformer(nn.Module):  # class line restored from the super() call below
    def __init__(self, num_classes, hparams):
        super(G2PTransformer, self).__init__()
        # Pretrained Chinese BERT encoder loaded from a local checkpoint directory.
        self.bert = BertModel.from_pretrained('./bert/bert-base-chinese')
        # Transformer-based polyphone classifier defined in Example #2.
        self.poly_phoneme_classifier = Poly_Phoneme_Classifier(hparams)

        self.linear = nn.Linear(1024, num_classes)

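        # Feature dimensions; 1324 = 1024 + 300, i.e. the select_model_dim used by
        # Poly_Phoneme_Classifier in Example #2.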
        self.bert_embedding_features_dim = 768
        self.transformer_embedding_features_dim = 1324
        self.embedding_features_dim = 1024
        self.select_model_hidden_dim = 512

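        # Maps 768-dim BERT features to transformer_embedding_features_dim (1324)
        # through a 512-dim hidden layer.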
        self.linear_pre = nn.Sequential(
            nn.Linear(self.bert_embedding_features_dim, self.select_model_hidden_dim),
            parse_nk.LayerNormalization(self.select_model_hidden_dim),
            nn.ReLU(),
            nn.Linear(self.select_model_hidden_dim, self.transformer_embedding_features_dim),
            )

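        # Maps 1024-dim classifier features down to num_classes through a 512-dim hidden layer.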
        self.linear_aft = nn.Sequential(
            nn.Linear(self.embedding_features_dim, self.select_model_hidden_dim),
            parse_nk.LayerNormalization(self.select_model_hidden_dim),
            nn.ReLU(),
            nn.Linear(self.select_model_hidden_dim, num_classes),
            )
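
A minimal construction sketch for Example #1, assuming a hypothetical hparams object: the only field the shown code reads is n_pinyin_symbols (consumed by Poly_Phoneme_Classifier in Example #2), num_classes is illustrative, and the local './bert/bert-base-chinese' checkpoint must already exist on disk.

# Hypothetical usage; hparams fields and num_classes are illustrative, not from the source.
from types import SimpleNamespace

hparams = SimpleNamespace(n_pinyin_symbols=1500)           # assumed value
model = G2PTransformer(num_classes=1500, hparams=hparams)
n_params = sum(p.numel() for p in model.parameters())      # quick sanity check
print(f"G2PTransformer parameters: {n_params}")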
Example #2
class Poly_Phoneme_Classifier(nn.Module):  # class line restored from the super() call below
    def __init__(self,
                 hparams,
                 num_layers=1,
                 num_heads=2,
                 d_kv=32,
                 d_ff=1024,
                 d_positional=None,
                 num_layers_position_only=0,
                 relu_dropout=0.1,
                 residual_dropout=0.1,
                 attention_dropout=0.1):
        super(Poly_Phoneme_Classifier, self).__init__()

        self.num_layers_position_only = num_layers_position_only
        self.embedding_features_dim = 1024
        self.structure_features_dim = 300
        self.select_model_dim = self.embedding_features_dim + self.structure_features_dim
        self.select_model_hidden_dim = 512
        self.n_pinyin_symbols = hparams.n_pinyin_symbols

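        # Per-head key and value projections share the same dimension d_kv.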
        d_k = d_v = d_kv

        self.linear_pre = nn.Sequential(
            nn.Linear(self.select_model_dim, self.select_model_hidden_dim),
            parse_nk.LayerNormalization(self.select_model_hidden_dim),
            nn.ReLU(),
            nn.Linear(self.select_model_hidden_dim,
                      self.embedding_features_dim),
        )

        self.stacks = []
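        # attn/ff pairs live in a plain Python list, so add_module() below registers
        # each one (and its parameters) with this nn.Module explicitly.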
        for i in range(num_layers):
            attn = parse_nk.MultiHeadAttention(
                num_heads,
                self.embedding_features_dim,
                d_k,
                d_v,
                residual_dropout=residual_dropout,
                attention_dropout=attention_dropout,
                d_positional=d_positional)
            if d_positional is None:
                ff = parse_nk.PositionwiseFeedForward(
                    self.embedding_features_dim,
                    d_ff,
                    relu_dropout=relu_dropout,
                    residual_dropout=residual_dropout)
            else:
                ff = parse_nk.PartitionedPositionwiseFeedForward(
                    self.embedding_features_dim,
                    d_ff,
                    d_positional,
                    relu_dropout=relu_dropout,
                    residual_dropout=residual_dropout)

            self.add_module(f"select_attn_{i}", attn)
            self.add_module(f"select_ff_{i}", ff)
            self.stacks.append((attn, ff))

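        # Maps 1024-dim features to scores over the n_pinyin_symbols pinyin labels.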
        self.linear_label = nn.Sequential(
            nn.Linear(self.embedding_features_dim,
                      self.select_model_hidden_dim),
            parse_nk.LayerNormalization(self.select_model_hidden_dim),
            nn.ReLU(),
            nn.Linear(self.select_model_hidden_dim, self.n_pinyin_symbols),
        )

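        # Gumbel_Softmax and Mask_Softmax are project-defined modules; their
        # definitions are not shown in this example.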
        self.gumbel_softmax = Gumbel_Softmax()
        self.mask_softmax = Mask_Softmax()
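
A minimal stand-alone construction sketch for Example #2, again with a hypothetical hparams object; the transformer hyperparameters passed here are illustrative overrides of the defaults in the signature above.

# Hypothetical usage; n_pinyin_symbols and the overrides are illustrative only.
from types import SimpleNamespace

hparams = SimpleNamespace(n_pinyin_symbols=1500)           # assumed value
clf = Poly_Phoneme_Classifier(hparams,
                              num_layers=2,                # stack two attention/feed-forward blocks
                              num_heads=4,
                              d_kv=64)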