def __init__(self, hidden_size):
    """Build an LSTM or orthogonal recurrent cell plus a linear readout (MSE task)."""
    super(Model, self).__init__()
    # Fix: the companion forward() builds LSTM zero states from self.hidden_size,
    # which the original __init__ never stored (AttributeError in "lstm" mode).
    self.hidden_size = hidden_size
    if args.mode == "lstm":
        self.rnn = nn.LSTMCell(n_input, hidden_size)
    else:
        self.rnn = OrthogonalRNN(n_input, hidden_size,
                                 initializer_skew=init, mode=mode, param=param)
    self.lin = nn.Linear(hidden_size, n_classes)
    self.loss_func = nn.MSELoss()
def __init__(self, hidden_size, permute):
    """(Permuted-)MNIST model: LSTM or orthogonal cell, fixed pixel permutation buffer."""
    super(Model, self).__init__()
    # Fix: the companion forward() reads self.hidden_size for the LSTM zero state,
    # but the original never stored it (AttributeError in "lstm" mode).
    self.hidden_size = hidden_size
    self.permute = permute
    # Fixed seed so the pixel permutation is identical across runs/checkpoints.
    rng = np.random.RandomState(92916)
    self.register_buffer("permutation", torch.LongTensor(rng.permutation(784)))
    if args.mode == "lstm":
        self.rnn = nn.LSTMCell(1, hidden_size)
    else:
        self.rnn = OrthogonalRNN(1, hidden_size,
                                 initializer_skew=init, mode=mode, param=param)
    self.lin = nn.Linear(hidden_size, n_classes)
    self.loss_func = nn.CrossEntropyLoss()
def __init__(self, hidden_size, permute):
    """Select and build the recurrent cell named by args.mode, plus a linear classifier."""
    super(Model, self).__init__()
    self.hidden_size = hidden_size
    self.permute = permute
    # Seeded generator so the pixel permutation is reproducible for a given run.
    rng = np.random.RandomState(args.manualSeed)
    self.register_buffer("permutation", torch.LongTensor(rng.permutation(784)))
    cell = args.mode
    if cell == "lstm":
        self.rnn = LSTMCell(1, hidden_size)
    elif cell == "mlstm":
        self.rnn = MomentumLSTMCell(1, hidden_size, mu=args.mu, epsilon=args.epsilon)
    elif cell == "alstm":
        self.rnn = AdamLSTMCell(1, hidden_size, mu=args.mu, epsilon=args.epsilon,
                                mus=args.mus)
    elif cell == "nlstm":
        self.rnn = NesterovLSTMCell(1, hidden_size, epsilon=args.epsilon)
    elif cell == "mdtriv":
        self.rnn = OrthogonalMomentumRNN(1, hidden_size, initializer_skew=init,
                                         mode=mode, param=param,
                                         mu=args.mu, epsilon=args.epsilon)
    elif cell == "adtriv":
        self.rnn = OrthogonalAdamRNN(1, hidden_size, initializer_skew=init,
                                     mode=mode, param=param,
                                     mu=args.mu, epsilon=args.epsilon, mus=args.mus)
    elif cell == "ndtriv":
        self.rnn = OrthogonalNesterovRNN(1, hidden_size, initializer_skew=init,
                                         mode=mode, param=param, epsilon=args.epsilon)
    else:
        self.rnn = OrthogonalRNN(1, hidden_size, initializer_skew=init,
                                 mode=mode, param=param)
    self.lin = nn.Linear(hidden_size, n_classes)
    self.loss_func = nn.CrossEntropyLoss()
def __init__(self, n_classes, hidden_size):
    """Copy-task model: recurrent cell over n_classes+1 symbols, linear readout."""
    super(Model, self).__init__()
    self.hidden_size = hidden_size
    if args.mode == "lstm":
        cell = nn.LSTMCell(n_classes + 1, hidden_size)
    else:
        cell = OrthogonalRNN(n_classes + 1, hidden_size,
                             initializer_skew=init, mode=mode, param=param)
    self.rnn = cell
    self.lin = nn.Linear(hidden_size, n_classes)
    self.loss_func = nn.CrossEntropyLoss()
    self.reset_parameters()
class Model(nn.Module):
    """Sequence regressor: per-timestep linear readout over an RNN, trained with MSE."""

    def __init__(self, hidden_size):
        super(Model, self).__init__()
        # Fix: forward() reads self.hidden_size to build the LSTM zero state, but the
        # original __init__ never stored it (AttributeError in "lstm" mode).
        self.hidden_size = hidden_size
        if args.mode == "lstm":
            self.rnn = nn.LSTMCell(n_input, hidden_size)
        else:
            self.rnn = OrthogonalRNN(n_input, hidden_size,
                                     initializer_skew=init, mode=mode, param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.MSELoss()

    def forward(self, inputs):
        """Run the cell over dim 1 of `inputs`; return stacked per-step logits."""
        if isinstance(self.rnn, OrthogonalRNN):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((inputs.size(0), self.hidden_size), device=inputs.device),
                     torch.zeros((inputs.size(0), self.hidden_size), device=inputs.device))
        outputs = []
        for input in torch.unbind(inputs, dim=1):
            out_rnn, state = self.rnn(input, state)
            if isinstance(self.rnn, nn.LSTMCell):
                # LSTMCell returns (h, c) unpacked above; repack for the next step.
                state = (out_rnn, state)
            outputs.append(self.lin(out_rnn))
        return torch.stack(outputs, dim=1)

    def loss(self, logits, y, len_batch):
        # masked_loss handles variable-length sequences in the batch.
        return masked_loss(self.loss_func, logits, y, len_batch)
class Model(nn.Module):
    """(Permuted-)MNIST pixel-by-pixel classifier: LSTM or orthogonal RNN + readout."""

    def __init__(self, hidden_size, permute):
        super(Model, self).__init__()
        # Fix: forward() reads self.hidden_size for the LSTM zero state, but the
        # original __init__ never stored it (AttributeError in "lstm" mode).
        self.hidden_size = hidden_size
        self.permute = permute
        # Fixed seed so the pixel permutation is identical across runs.
        rng = np.random.RandomState(92916)
        self.register_buffer("permutation", torch.LongTensor(rng.permutation(784)))
        if args.mode == "lstm":
            self.rnn = nn.LSTMCell(1, hidden_size)
        else:
            # Fix: every other OrthogonalRNN call site in this file passes
            # initializer_skew=..., not skew_initializer=...; made consistent.
            self.rnn = OrthogonalRNN(1, hidden_size,
                                     initializer_skew=init, mode=mode, param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, inputs):
        if self.permute:
            inputs = inputs[:, self.permutation]
        if isinstance(self.rnn, OrthogonalRNN):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((inputs.size(0), self.hidden_size), device=inputs.device),
                     torch.zeros((inputs.size(0), self.hidden_size), device=inputs.device))
        for input in torch.unbind(inputs, dim=1):
            out_rnn, state = self.rnn(input.unsqueeze(dim=1), state)
            if isinstance(self.rnn, nn.LSTMCell):
                state = (out_rnn, state)
        # Fix: classify from the last cell output. The original used self.lin(state),
        # which is an (h, c) tuple in "lstm" mode and would crash; the sibling
        # momentum-RNN model in this file also returns self.lin(out_rnn).
        return self.lin(out_rnn)

    def loss(self, logits, y):
        l = self.loss_func(logits, y)
        if isinstance(self.rnn, OrthogonalRNN):
            # parametrization_trick is required by the orthogonal parametrization;
            # a no-op otherwise (see the comment in the copy-task model below).
            return parametrization_trick(model=self, loss=l)
        else:
            return l

    def correct(self, logits, y):
        # Count of correct argmax predictions in the batch.
        return torch.eq(torch.argmax(logits, dim=1), y).float().sum()
class Model(nn.Module):
    """Copy-task model: per-timestep classification over n_classes symbols."""

    def __init__(self, n_classes, hidden_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        if args.mode == "lstm":
            self.rnn = nn.LSTMCell(n_classes + 1, hidden_size)
        else:
            # Fix: every other OrthogonalRNN call site in this file passes
            # initializer_skew=..., not skew_initializer=...; made consistent.
            self.rnn = OrthogonalRNN(n_classes + 1, hidden_size,
                                     initializer_skew=init, mode=mode, param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.CrossEntropyLoss()
        self.reset_parameters()

    def reset_parameters(self):
        """He-initialize the readout weights and zero its bias."""
        nn.init.kaiming_normal_(self.lin.weight.data, nonlinearity="relu")
        nn.init.constant_(self.lin.bias.data, 0)

    def forward(self, inputs):
        """Run the cell over dim 1 of `inputs`; return stacked per-step logits."""
        if isinstance(self.rnn, OrthogonalRNN):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((inputs.size(0), self.hidden_size), device=inputs.device),
                     torch.zeros((inputs.size(0), self.hidden_size), device=inputs.device))
        outputs = []
        for input in torch.unbind(inputs, dim=1):
            out_rnn, state = self.rnn(input, state)
            if isinstance(self.rnn, nn.LSTMCell):
                # LSTMCell returns (h, c) unpacked above; repack for the next step.
                state = (out_rnn, state)
            outputs.append(self.lin(out_rnn))
        return torch.stack(outputs, dim=1)

    def loss(self, logits, y):
        # Generalized: derive the class count from the readout layer instead of a
        # hard-coded 9 (identical whenever n_classes == 9, correct otherwise).
        l = self.loss_func(logits.view(-1, self.lin.out_features), y.view(-1))
        # If the model does not have any OrthogonalRNN (or any Parametrization object)
        # this is a noop
        return parametrization_trick(model=self, loss=l)

    def accuracy(self, logits, y):
        # Mean per-timestep classification accuracy.
        return torch.eq(torch.argmax(logits, dim=2), y).float().mean()
def __init__(self, hidden_size):
    """Build the cell named by args.mode (LSTM/momentum/orthogonal variants) + readout."""
    super(Model, self).__init__()
    self.hidden_size = hidden_size
    cell = args.mode
    if cell == "lstm":
        # fg_init=-4.0: forget-gate bias initial value passed through to the cell.
        self.rnn = LSTMCell(n_input, hidden_size, fg_init=-4.0)
    elif cell == "mlstm":
        self.rnn = MomentumLSTMCell(n_input, hidden_size, mu=args.mu,
                                    epsilon=args.epsilon, fg_init=-4.0)
    elif cell == "mdtriv":
        self.rnn = OrthogonalMomentumRNN(n_input, hidden_size, initializer_skew=init,
                                         mode=mode, param=param,
                                         mu=args.mu, epsilon=args.epsilon)
    elif cell == "adtriv":
        self.rnn = OrthogonalAdamRNN(n_input, hidden_size, initializer_skew=init,
                                     mode=mode, param=param,
                                     mu=args.mu, epsilon=args.epsilon, mus=args.mus)
    elif cell == "ndtriv":
        self.rnn = OrthogonalNesterovRNN(n_input, hidden_size, initializer_skew=init,
                                         mode=mode, param=param, epsilon=args.epsilon)
    else:
        self.rnn = OrthogonalRNN(n_input, hidden_size, initializer_skew=init,
                                 mode=mode, param=param)
    self.lin = nn.Linear(hidden_size, n_classes)
    self.loss_func = nn.MSELoss()
class Model(nn.Module):
    """(Permuted-)MNIST classifier built on one of several recurrent-cell variants."""

    def __init__(self, hidden_size, permute):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        self.permute = permute
        # Seeded generator so the pixel permutation is reproducible for a given run.
        rng = np.random.RandomState(args.manualSeed)
        self.register_buffer("permutation", torch.LongTensor(rng.permutation(784)))
        cell = args.mode
        if cell == "lstm":
            self.rnn = LSTMCell(1, hidden_size)
        elif cell == "mlstm":
            self.rnn = MomentumLSTMCell(1, hidden_size, mu=args.mu, epsilon=args.epsilon)
        elif cell == "alstm":
            self.rnn = AdamLSTMCell(1, hidden_size, mu=args.mu, epsilon=args.epsilon,
                                    mus=args.mus)
        elif cell == "nlstm":
            self.rnn = NesterovLSTMCell(1, hidden_size, epsilon=args.epsilon)
        elif cell == "mdtriv":
            self.rnn = OrthogonalMomentumRNN(1, hidden_size, initializer_skew=init,
                                             mode=mode, param=param,
                                             mu=args.mu, epsilon=args.epsilon)
        elif cell == "adtriv":
            self.rnn = OrthogonalAdamRNN(1, hidden_size, initializer_skew=init,
                                         mode=mode, param=param, mu=args.mu,
                                         epsilon=args.epsilon, mus=args.mus)
        elif cell == "ndtriv":
            self.rnn = OrthogonalNesterovRNN(1, hidden_size, initializer_skew=init,
                                             mode=mode, param=param, epsilon=args.epsilon)
        else:
            self.rnn = OrthogonalRNN(1, hidden_size, initializer_skew=init,
                                     mode=mode, param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, inputs):
        """Consume pixels one at a time; classify from the final cell output."""
        if self.permute:
            inputs = inputs[:, self.permutation]
        batch = inputs.size(0)
        if isinstance(self.rnn, (OrthogonalRNN, OrthogonalMomentumRNN,
                                 OrthogonalAdamRNN, OrthogonalNesterovRNN)):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((batch, self.hidden_size), device=inputs.device),
                     torch.zeros((batch, self.hidden_size), device=inputs.device))
        # Auxiliary momentum buffers: LSTM variants carry 4*hidden values (one per
        # gate pre-activation), orthogonal variants carry hidden-sized ones;
        # Adam variants additionally track a second buffer s.
        if isinstance(self.rnn, (MomentumLSTMCell, NesterovLSTMCell)):
            v = torch.zeros((batch, 4 * self.hidden_size), device=inputs.device)
        elif isinstance(self.rnn, AdamLSTMCell):
            v = torch.zeros((batch, 4 * self.hidden_size), device=inputs.device)
            s = torch.zeros((batch, 4 * self.hidden_size), device=inputs.device)
        elif isinstance(self.rnn, OrthogonalAdamRNN):
            v = torch.zeros((batch, self.hidden_size), device=inputs.device)
            s = torch.zeros((batch, self.hidden_size), device=inputs.device)
        elif isinstance(self.rnn, (OrthogonalMomentumRNN, OrthogonalNesterovRNN)):
            v = torch.zeros((batch, self.hidden_size), device=inputs.device)
        step = 0
        for pixel in torch.unbind(inputs, dim=1):
            step += 1
            x = pixel.unsqueeze(dim=1)
            if isinstance(self.rnn, (MomentumLSTMCell, OrthogonalMomentumRNN)):
                out_rnn, state, v = self.rnn(x, state, v)
            elif isinstance(self.rnn, (AdamLSTMCell, OrthogonalAdamRNN)):
                out_rnn, state, v, s = self.rnn(x, state, v, s)
            elif isinstance(self.rnn, (NesterovLSTMCell, OrthogonalNesterovRNN)):
                out_rnn, state, v = self.rnn(x, state, v, k=step)
                # Periodic momentum restart; args.restart <= 0 disables it.
                if args.restart > 0 and not (step % args.restart):
                    step = 0
            else:
                out_rnn, state = self.rnn(x, state)
        return self.lin(out_rnn)

    def loss(self, logits, y):
        return self.loss_func(logits, y)

    def correct(self, logits, y):
        # Count of correct argmax predictions in the batch.
        return torch.eq(torch.argmax(logits, dim=1), y).float().sum()
class Model(nn.Module):
    """Sequence regressor over momentum/LSTM/orthogonal cells with per-step readout."""

    def __init__(self, hidden_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        cell = args.mode
        if cell == "lstm":
            # fg_init=-4.0: forget-gate bias initial value passed through to the cell.
            self.rnn = LSTMCell(n_input, hidden_size, fg_init=-4.0)
        elif cell == "mlstm":
            self.rnn = MomentumLSTMCell(n_input, hidden_size, mu=args.mu,
                                        epsilon=args.epsilon, fg_init=-4.0)
        elif cell == "alstm":
            self.rnn = AdamLSTMCell(n_input, hidden_size, mu=args.mu,
                                    epsilon=args.epsilon, mus=args.mus, fg_init=-4.0)
        elif cell == "nlstm":
            self.rnn = NesterovLSTMCell(n_input, hidden_size,
                                        epsilon=args.epsilon, fg_init=-4.0)
        elif cell == "mdtriv":
            self.rnn = OrthogonalMomentumRNN(n_input, hidden_size, initializer_skew=init,
                                             mode=mode, param=param,
                                             mu=args.mu, epsilon=args.epsilon)
        else:
            self.rnn = OrthogonalRNN(n_input, hidden_size, initializer_skew=init,
                                     mode=mode, param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.MSELoss()

    def forward(self, inputs):
        """Run the cell over dim 1 of `inputs`; return stacked per-step readouts."""
        batch = inputs.size(0)
        if isinstance(self.rnn, (OrthogonalRNN, OrthogonalMomentumRNN)):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((batch, self.hidden_size), device=inputs.device),
                     torch.zeros((batch, self.hidden_size), device=inputs.device))
        # Auxiliary momentum buffers: LSTM variants carry 4*hidden values (one per
        # gate pre-activation); the orthogonal momentum cell carries hidden-sized v;
        # the Adam cell additionally tracks a second buffer s.
        if isinstance(self.rnn, (MomentumLSTMCell, NesterovLSTMCell)):
            v = torch.zeros((batch, 4 * self.hidden_size), device=inputs.device)
        elif isinstance(self.rnn, AdamLSTMCell):
            v = torch.zeros((batch, 4 * self.hidden_size), device=inputs.device)
            s = torch.zeros((batch, 4 * self.hidden_size), device=inputs.device)
        elif isinstance(self.rnn, OrthogonalMomentumRNN):
            v = torch.zeros((batch, self.hidden_size), device=inputs.device)
        outputs = []
        step = 0
        for x in torch.unbind(inputs, dim=1):
            step += 1
            if isinstance(self.rnn, (MomentumLSTMCell, OrthogonalMomentumRNN)):
                out_rnn, state, v = self.rnn(x, state, v)
            elif isinstance(self.rnn, AdamLSTMCell):
                out_rnn, state, v, s = self.rnn(x, state, v, s)
            elif isinstance(self.rnn, NesterovLSTMCell):
                out_rnn, state, v = self.rnn(x, state, v, k=step)
                # Periodic momentum restart; args.restart <= 0 disables it.
                if args.restart > 0 and not (step % args.restart):
                    step = 0
            else:
                out_rnn, state = self.rnn(x, state)
            outputs.append(self.lin(out_rnn))
        return torch.stack(outputs, dim=1)

    def loss(self, logits, y, len_batch):
        # masked_loss handles variable-length sequences in the batch.
        return masked_loss(self.loss_func, logits, y, len_batch)