Example #1
import torch.nn as nn
import sru
import flop  # provides ProjectedLinear, a factorized (low-rank) linear layer

# AdaptiveEmbedding and AdaptiveLogSoftmax are assumed to be project-local
# modules (e.g. adapted from the Transformer-XL codebase).


class Model(nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()
        self.args = args
        # Vocabulary cutoffs shared by the adaptive embedding and softmax.
        # self.cutoffs = [20000, 60000]
        self.cutoffs = [10000, 20000, 40000, 60000, 100000]
        self.n_V = args.n_token
        self.n_e = args.n_e or args.n_proj
        self.n_d = args.n_d
        self.depth = args.depth
        self.drop = nn.Dropout(args.dropout)
        self.embedding_layer = AdaptiveEmbedding(
            self.n_V,
            self.n_e,
            self.n_d,
            self.cutoffs,
            div_val=args.div_val,
            div_freq=2,
            dropout=args.dropout_e,
        )
        # SRU stack; custom_m replaces the recurrence's internal weight matrix
        # with a factorized ProjectedLinear producing the fused 3 * n_d outputs.
        self.rnn = sru.SRU(
            self.n_d,
            self.n_d,
            self.depth,
            projection_size=args.n_proj,
            dropout=args.dropout,
            highway_bias=args.bias,
            layer_norm=args.layer_norm,
            rescale=args.rescale,
            custom_m=flop.ProjectedLinear(
                self.n_d,
                self.n_d * 3,
                proj_features=args.n_proj,
                bias=False,
            ),
        )
        self.output_layer = AdaptiveLogSoftmax(
            self.n_V,
            self.n_e,
            self.n_d,
            self.cutoffs,
            div_val=args.div_val,
            div_freq=2,
            dropout=args.dropout_e,
            keep_order=False,
        )
        # init_weights / tie_weights are defined in the full class (Example #2).
        self.init_weights()
        if not args.not_tie:
            self.tie_weights()
Example #2
import torch.nn as nn
import sru

# AdaptiveEmbedding, AdaptiveLogSoftmax, and CustomLinear are assumed to be
# project-local modules; CustomLinear stands in for the SRU's internal weight
# matrix, producing the fused 3 * n_d outputs.


class Model(nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()
        self.args = args
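        # Vocabulary cutoffs shared by the adaptive embedding and softmax.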
        # self.cutoffs = [20000, 60000]
        self.cutoffs = [10000, 20000, 40000, 60000, 100000]
        self.n_V = args.n_token
        self.n_e = args.n_e or args.n_proj
        self.n_d = args.n_d
        self.depth = args.depth
        self.drop = nn.Dropout(args.dropout)
        self.embedding_layer = AdaptiveEmbedding(
            self.n_V,
            self.n_e,
            self.n_d,
            self.cutoffs,
            div_val=args.div_val,
            div_freq=2,
            dropout=args.dropout_e,
        )
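        # SRU stack; custom_m replaces the recurrence's internal weight matrix
        # with a custom module producing the fused 3 * n_d outputs.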
        self.rnn = sru.SRU(
            self.n_d,
            self.n_d,
            self.depth,
            dropout=args.dropout,
            highway_bias=args.bias,
            layer_norm=args.layer_norm,
            rescale=args.rescale,
            custom_m=CustomLinear(self.n_d, self.n_d * 3, bias=False),
        )
        self.output_layer = AdaptiveLogSoftmax(
            self.n_V,
            self.n_e,
            self.n_d,
            self.cutoffs,
            div_val=args.div_val,
            div_freq=2,
            dropout=args.dropout_e,
            keep_order=False,
        )
        self.init_weights()
        if not args.not_tie:
            self.tie_weights()

    def tie_weights(self):
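        # Share weight matrices (and projections) between the input embedding
        # and the adaptive softmax output layer: the standard weight-tying
        # trick for language models.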
        for i in range(len(self.output_layer.out_layers)):
            self.embedding_layer.emb_layers[i].weight = \
                self.output_layer.out_layers[i].weight

        for i in range(len(self.output_layer.out_projs)):
            self.embedding_layer.emb_projs[i] = self.output_layer.out_projs[i]

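        # If both layers maintain masks (e.g. pruning masks), point the output
        # layer at the embedding layer's masks so they are shared; delattr
        # first so any registered buffer is replaced by a plain attribute.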
        if hasattr(self.embedding_layer, "masks") and hasattr(
                self.output_layer, "masks"):
            delattr(self.output_layer, "masks")
            setattr(self.output_layer, "masks", self.embedding_layer.masks)

    def init_weights(self, init_range=0.03, reinit_rnn=False):
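        # Uniform init for weight matrices; 1-D parameters (biases) are zeroed.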
        params = list(self.embedding_layer.parameters()) + list(
            self.output_layer.parameters())
        for p in params:
            if p.dim() > 1:  # matrix
                p.data.uniform_(-init_range, init_range)
            else:
                p.data.zero_()
        if reinit_rnn:
            for p in self.rnn.parameters():
                if p.dim() > 1:  # matrix
                    p.data.uniform_(-init_range, init_range)

    def forward(self, x, y, hidden):
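        # x, y: (seq_len, batch) token ids and next-token targets; returns
        # per-token loss of shape (seq_len, batch) and the new hidden state.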
        emb = self.drop(self.embedding_layer(x))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        output = output.view(-1, output.size(2))
        loss = self.output_layer(output, y.view(-1))
        loss = loss.view(y.size(0), -1)
        return loss, hidden

    def init_hidden(self, batch_size):
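        # Zero initial state, allocated on the same device/dtype as the model.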
        weight = next(self.parameters()).data
        zeros = weight.new(self.depth, batch_size, self.n_d).zero_()
        return zeros
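
A minimal usage sketch (not part of the source): it assumes the project-local
modules above (AdaptiveEmbedding, AdaptiveLogSoftmax, CustomLinear) are
importable, and every args field value below is illustrative rather than
taken from the original configuration. Note n_token must exceed the largest
cutoff (100000) for the hardcoded cutoffs to be valid.

import torch
from types import SimpleNamespace

args = SimpleNamespace(
    n_token=267735, n_e=1024, n_proj=512, n_d=2048, depth=6,
    dropout=0.1, dropout_e=0.1, div_val=4, bias=-3.0,
    layer_norm=True, rescale=False, not_tie=False,
)
model = Model(args)
hidden = model.init_hidden(batch_size=32)
x = torch.randint(0, args.n_token, (64, 32))  # (seq_len, batch) token ids
y = torch.randint(0, args.n_token, (64, 32))  # next-token targets
loss, hidden = model(x, y, hidden)            # loss has shape (64, 32)
print(loss.mean())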