def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
    """Build an embedding -> recurrent core -> linear decoder model.

    Args:
        rnn_type: 'LSTM' or 'GRU' to use the matching ``nn`` recurrent layer,
            or 'DRNN' to use ``drnn.DRNN`` with 'GRU' cells.
        ntoken: number of rows of the embedding and outputs of the decoder.
        ninp: embedding width (input size of the recurrent core).
        nhid: hidden size of the recurrent core (input size of the decoder).
        nlayers: number of recurrent layers.
        dropout: rate for the ``nn.Dropout`` layer and for the LSTM/GRU
            ``dropout`` argument. The DRNN branch is built with a literal 0
            in that position instead.
        tie_weights: share the embedding weight with the decoder weight;
            requires ``nhid == ninp``.

    Raises:
        ValueError: if ``rnn_type`` is not 'LSTM', 'GRU' or 'DRNN', or if
            ``tie_weights`` is set while ``nhid != ninp``.
    """
    super(RNNModel, self).__init__()
    # Plain config attributes are set up front so init_weights (if any) can
    # read them; they do not affect submodule registration order.
    self.rnn_type = rnn_type
    self.nhid = nhid
    self.nlayers = nlayers

    self.drop = nn.Dropout(dropout)
    self.encoder = nn.Embedding(ntoken, ninp)
    if rnn_type in ['LSTM', 'GRU']:
        self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
    elif rnn_type == 'DRNN':
        self.rnn = drnn.DRNN(ninp, nhid, nlayers, 0, 'GRU')
    else:
        # Previously an unknown type silently left self.rnn undefined.
        raise ValueError(
            "Unsupported rnn_type {!r}; expected 'LSTM', 'GRU' or 'DRNN'".format(rnn_type))
    self.decoder = nn.Linear(nhid, ntoken)

    if tie_weights:
        if nhid != ninp:
            raise ValueError(
                'When using the tied flag, nhid must be equal to emsize')
        self.decoder.weight = self.encoder.weight

    # Call init_weights only if this class defines it. Checking for the method
    # (instead of catching AttributeError around the call) keeps real errors
    # raised inside init_weights from being silently discarded.
    init_weights = getattr(self, 'init_weights', None)
    if init_weights is not None:
        init_weights()
def __init__(self):
    """Build an embedding -> 9-layer DRNN with 'RNN' cells -> linear projection stack.

    ``Embedding(10, 10)`` feeds a DRNN with input and hidden size 10, whose
    features are projected from 10 to 8 by a linear layer.
    """
    torch.nn.Module.__init__(self)
    self.embed = torch.nn.Embedding(10, 10)
    # The cell type must be passed by keyword. Every other DRNN call here puts
    # a numeric value in the fourth positional slot (presumably dropout,
    # e.g. `drnn.DRNN(..., nlayers, 0, 'GRU')`) and the cell type fifth or as
    # `cell_type=`, so a positional 'RNN' here never selected the cell type.
    self.drnn = drnn.DRNN(10, 10, 9, cell_type='RNN')
    self.project = torch.nn.Linear(10, 8)
def test(self):
    """The DRNN accepts caller-supplied initial hidden states and still
    returns an output with the input's (24, 3, 10) size."""
    model = drnn.DRNN(10, 10, 4, 0, 'GRU')
    # One initial state per layer; layer i gets a leading dimension of 2**i
    # (presumably matching that layer's dilation — confirm against drnn.DRNN).
    initial_hidden = [torch.randn(2 ** i, 3, 10) for i in range(4)]
    x = torch.randn(24, 3, 10)
    result = model(x, initial_hidden)
    output = result[0]
    self.assertEqual(tuple(output.size()), (24, 3, 10))
def test(self):
    """The DRNN output has the same (23, 3, 10) size as its input.

    23 steps is not a multiple of 2, 4 or 8, presumably to exercise the
    padding path for sequences that do not divide evenly.
    """
    model = drnn.DRNN(10, 10, 4, 0, 'GRU')
    x = torch.randn(23, 3, 10)
    out = model(x)[0]
    # One assertEqual on the whole size reports the actual shape on failure
    # and also pins the rank, which per-dimension assertTrue checks did not.
    self.assertEqual(tuple(out.size()), (23, 3, 10))
def test(self):
    """Without an initial state, the second return value has one entry per layer."""
    num_layers = 4
    model = drnn.DRNN(10, 10, num_layers, 0, 'GRU')
    inputs = torch.autograd.Variable(torch.randn(23, 3, 10))
    layer_states = model(inputs)[1]
    self.assertEqual(len(layer_states), num_layers)
    # Diagnostic output only: show the size of each per-layer entry.
    for state in layer_states:
        print(state.size())
def __init__(self, n_inputs, n_hidden, n_layers, n_classes, cell_type="GRU"):
    """Dilated-RNN encoder followed by a linear classification head.

    Args:
        n_inputs: feature size fed to the DRNN.
        n_hidden: DRNN hidden size, which is also the head's input size.
        n_layers: number of DRNN layers.
        n_classes: number of output features of the linear head.
        cell_type: recurrent cell type forwarded to ``drnn.DRNN``.
    """
    super(Classifier, self).__init__()
    # Keep this assignment order: submodules are registered as they are set.
    self.drnn = drnn.DRNN(n_inputs, n_hidden, n_layers, cell_type=cell_type)
    self.linear = nn.Linear(in_features=n_hidden, out_features=n_classes)
def test(self):
    """Round-trip check for the rate-2 input preparation helpers.

    ``_prepare_inputs(x, 2)`` must place the odd time steps ``x[1::2]`` in
    the second group of ``x.size(1)`` columns, and ``_split_outputs(..., 2)``
    must restore the original tensor exactly.
    """
    model = drnn.DRNN(10, 10, 4, 0, 'GRU')
    x = torch.randn(24, 3, 10)
    rate = 2
    batch = x.size(1)

    split_x = model._prepare_inputs(x, rate)
    # torch.equal compares sizes as well as values; `(a == b).all()`
    # broadcasts and can report a match between differently-shaped tensors.
    self.assertTrue(torch.equal(x[1::rate], split_x[:, batch:, :]))

    unsplit_x = model._split_outputs(split_x, rate)
    self.assertTrue(torch.equal(x, unsplit_x))