Example no. 1
    def __init__(self, hidden_size):
        super(Model, self).__init__()
        if args.mode == "lstm":
            self.rnn = nn.LSTMCell(n_input, hidden_size)
        else:
            self.rnn = OrthogonalRNN(n_input, hidden_size, initializer_skew=init, mode=mode, param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.MSELoss()
Example no. 2
    def __init__(self, hidden_size, permute):
        super(Model, self).__init__()
        self.permute = permute
        permute = np.random.RandomState(92916)
        self.register_buffer("permutation", torch.LongTensor(permute.permutation(784)))
        if args.mode == "lstm":
            self.rnn = nn.LSTMCell(1, hidden_size)
        else:
            self.rnn = OrthogonalRNN(1, hidden_size, initializer_skew=init, mode=mode, param=param)

        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.CrossEntropyLoss()
Example no. 3
    def __init__(self, hidden_size, permute):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        self.permute = permute
        permute = np.random.RandomState(args.manualSeed)
        self.register_buffer("permutation",
                             torch.LongTensor(permute.permutation(784)))
        if args.mode == "lstm":
            self.rnn = LSTMCell(1, hidden_size)
        elif args.mode == "mlstm":
            self.rnn = MomentumLSTMCell(1,
                                        hidden_size,
                                        mu=args.mu,
                                        epsilon=args.epsilon)
        elif args.mode == "alstm":
            self.rnn = AdamLSTMCell(1,
                                    hidden_size,
                                    mu=args.mu,
                                    epsilon=args.epsilon,
                                    mus=args.mus)
        elif args.mode == "nlstm":
            self.rnn = NesterovLSTMCell(1, hidden_size, epsilon=args.epsilon)
        elif args.mode == "mdtriv":
            self.rnn = OrthogonalMomentumRNN(1,
                                             hidden_size,
                                             initializer_skew=init,
                                             mode=mode,
                                             param=param,
                                             mu=args.mu,
                                             epsilon=args.epsilon)
        elif args.mode == "adtriv":
            self.rnn = OrthogonalAdamRNN(1,
                                         hidden_size,
                                         initializer_skew=init,
                                         mode=mode,
                                         param=param,
                                         mu=args.mu,
                                         epsilon=args.epsilon,
                                         mus=args.mus)
        elif args.mode == "ndtriv":
            self.rnn = OrthogonalNesterovRNN(1,
                                             hidden_size,
                                             initializer_skew=init,
                                             mode=mode,
                                             param=param,
                                             epsilon=args.epsilon)
        else:
            self.rnn = OrthogonalRNN(1,
                                     hidden_size,
                                     initializer_skew=init,
                                     mode=mode,
                                     param=param)

        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.CrossEntropyLoss()
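
The constructors above read args.mode, args.mu, args.epsilon, args.mus, and args.manualSeed from a script-level args object that is not shown here. A hypothetical argparse setup that would satisfy them might look as follows; the flag semantics in the comments, the choices, and the defaults are assumptions, not the repository's actual parser.

import argparse

# Hypothetical flags assumed by the constructors above (names taken from the
# snippet; defaults and meanings are guesses).
parser = argparse.ArgumentParser()
parser.add_argument("--mode", default="mlstm",
                    choices=["lstm", "mlstm", "alstm", "nlstm",
                             "mdtriv", "adtriv", "ndtriv", "dtriv"])
parser.add_argument("--mu", type=float, default=0.9)        # momentum coefficient (assumed)
parser.add_argument("--epsilon", type=float, default=0.1)   # momentum step size (assumed)
parser.add_argument("--mus", type=float, default=0.999)     # second-moment coefficient for the Adam variants (assumed)
parser.add_argument("--manualSeed", type=int, default=5544)
parser.add_argument("--restart", type=int, default=0)       # Nesterov restart period, 0 disables (assumed)
args = parser.parse_args([])                                # [] keeps the defaults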
Example no. 4
    def __init__(self, n_classes, hidden_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        if args.mode == "lstm":
            self.rnn = nn.LSTMCell(n_classes + 1, hidden_size)
        else:
            self.rnn = OrthogonalRNN(n_classes + 1,
                                     hidden_size,
                                     initializer_skew=init,
                                     mode=mode,
                                     param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.CrossEntropyLoss()
        self.reset_parameters()
Example no. 5
class Model(nn.Module):
    def __init__(self, hidden_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        if args.mode == "lstm":
            self.rnn = nn.LSTMCell(n_input, hidden_size)
        else:
            self.rnn = OrthogonalRNN(n_input, hidden_size, initializer_skew=init, mode=mode, param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.MSELoss()

    def forward(self, inputs):
        if isinstance(self.rnn, OrthogonalRNN):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((inputs.size(0), self.hidden_size), device=inputs.device),
                     torch.zeros((inputs.size(0), self.hidden_size), device=inputs.device))
        outputs = []
        for input in torch.unbind(inputs, dim=1):
            out_rnn, state = self.rnn(input, state)
            if isinstance(self.rnn, nn.LSTMCell):
                state = (out_rnn, state)
            outputs.append(self.lin(out_rnn))
        return torch.stack(outputs, dim=1)

    def loss(self, logits, y, len_batch):
        return masked_loss(self.loss_func, logits, y, len_batch)
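
The forward pass above unrolls the recurrent cell one time step at a time with torch.unbind and re-packs the per-step outputs with torch.stack; for nn.LSTMCell, which returns (h, c), the pair is repacked into state so the next step sees the full state. A minimal self-contained sketch of that pattern with a stock nn.LSTMCell (the sizes here are made up for illustration):

import torch
import torch.nn as nn

batch, seq_len, n_input, hidden, n_classes = 4, 10, 3, 16, 5   # illustrative sizes
cell = nn.LSTMCell(n_input, hidden)
lin = nn.Linear(hidden, n_classes)

inputs = torch.randn(batch, seq_len, n_input)
state = (torch.zeros(batch, hidden), torch.zeros(batch, hidden))   # initial (h, c)

outputs = []
for step in torch.unbind(inputs, dim=1):   # one (batch, n_input) slice per time step
    h, c = cell(step, state)               # nn.LSTMCell returns the new hidden and cell states
    state = (h, c)                         # thread the full state to the next step
    outputs.append(lin(h))                 # project the hidden state at every step
logits = torch.stack(outputs, dim=1)       # (batch, seq_len, n_classes)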
Example no. 6
class Model(nn.Module):
    def __init__(self, hidden_size, permute):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        self.permute = permute
        permute = np.random.RandomState(92916)
        self.register_buffer("permutation",
                             torch.LongTensor(permute.permutation(784)))
        if args.mode == "lstm":
            self.rnn = nn.LSTMCell(1, hidden_size)
        else:
            self.rnn = OrthogonalRNN(1,
                                     hidden_size,
                                     skew_initializer=init,
                                     mode=mode,
                                     param=param)

        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, inputs):
        if self.permute:
            inputs = inputs[:, self.permutation]

        if isinstance(self.rnn, OrthogonalRNN):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((inputs.size(0), self.hidden_size),
                                 device=inputs.device),
                     torch.zeros((inputs.size(0), self.hidden_size),
                                 device=inputs.device))
        for input in torch.unbind(inputs, dim=1):
            out_rnn, state = self.rnn(input.unsqueeze(dim=1), state)
            if isinstance(self.rnn, nn.LSTMCell):
                state = (out_rnn, state)
        return self.lin(out_rnn)

    def loss(self, logits, y):
        l = self.loss_func(logits, y)
        if isinstance(self.rnn, OrthogonalRNN):
            return parametrization_trick(model=self, loss=l)
        else:
            return l

    def correct(self, logits, y):
        return torch.eq(torch.argmax(logits, dim=1), y).float().sum()
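
For the permuted variant, the constructor draws a fixed pixel permutation from a seeded NumPy RandomState and registers it as a buffer, so it is saved with the checkpoint and moved to the model's device along with the parameters. A small self-contained sketch of just that mechanism, with a dummy batch standing in for flattened 28x28 MNIST images:

import numpy as np
import torch
import torch.nn as nn

class PixelPermutation(nn.Module):
    """Holds a fixed 784-element permutation, as in the permuted-MNIST models above."""
    def __init__(self, permute=True, seed=92916):
        super().__init__()
        self.permute = permute
        rng = np.random.RandomState(seed)
        self.register_buffer("permutation", torch.LongTensor(rng.permutation(784)))

    def forward(self, inputs):             # inputs: (batch, 784) flattened images
        if self.permute:
            inputs = inputs[:, self.permutation]
        return inputs

x = torch.randn(2, 784)
print(PixelPermutation()(x).shape)         # torch.Size([2, 784]), pixels reordered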
Example no. 7
class Model(nn.Module):
    def __init__(self, n_classes, hidden_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        if args.mode == "lstm":
            self.rnn = nn.LSTMCell(n_classes + 1, hidden_size)
        else:
            self.rnn = OrthogonalRNN(n_classes + 1,
                                     hidden_size,
                                     skew_initializer=init,
                                     mode=mode,
                                     param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.CrossEntropyLoss()
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_normal_(self.lin.weight.data, nonlinearity="relu")
        nn.init.constant_(self.lin.bias.data, 0)

    def forward(self, inputs):
        if isinstance(self.rnn, OrthogonalRNN):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((inputs.size(0), self.hidden_size),
                                 device=inputs.device),
                     torch.zeros((inputs.size(0), self.hidden_size),
                                 device=inputs.device))
        outputs = []
        for input in torch.unbind(inputs, dim=1):
            out_rnn, state = self.rnn(input, state)
            if isinstance(self.rnn, nn.LSTMCell):
                state = (out_rnn, state)
            outputs.append(self.lin(out_rnn))
        return torch.stack(outputs, dim=1)

    def loss(self, logits, y):
        l = self.loss_func(logits.view(-1, 9), y.view(-1))
        # If the model does not have any OrthogonalRNN (or any Parametrization object) this is a no-op
        return parametrization_trick(model=self, loss=l)

    def accuracy(self, logits, y):
        return torch.eq(torch.argmax(logits, dim=2), y).float().mean()
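
The loss above flattens the per-time-step logits to shape (batch * seq_len, n_classes) so a single nn.CrossEntropyLoss call covers every position; the hard-coded 9 is the number of output classes in that script. A self-contained sketch of the same reshaping and of the accuracy computation, with made-up sizes:

import torch
import torch.nn as nn

batch, seq_len, n_classes = 4, 6, 9        # illustrative; 9 matches the view(-1, 9) above
logits = torch.randn(batch, seq_len, n_classes)
targets = torch.randint(0, n_classes, (batch, seq_len))

loss_func = nn.CrossEntropyLoss()
loss = loss_func(logits.view(-1, n_classes), targets.view(-1))   # flatten batch and time

# Per-position accuracy, mirroring accuracy() above (argmax over the class dimension).
accuracy = torch.eq(torch.argmax(logits, dim=2), targets).float().mean()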
Example no. 8
    def __init__(self, hidden_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        if args.mode == "lstm":
            self.rnn = LSTMCell(n_input, hidden_size, fg_init=-4.0)
        elif args.mode == "mlstm":
            self.rnn = MomentumLSTMCell(n_input,
                                        hidden_size,
                                        mu=args.mu,
                                        epsilon=args.epsilon,
                                        fg_init=-4.0)
        elif args.mode == "mdtriv":
            self.rnn = OrthogonalMomentumRNN(n_input,
                                             hidden_size,
                                             initializer_skew=init,
                                             mode=mode,
                                             param=param,
                                             mu=args.mu,
                                             epsilon=args.epsilon)
        elif args.mode == "adtriv":
            self.rnn = OrthogonalAdamRNN(n_input,
                                         hidden_size,
                                         initializer_skew=init,
                                         mode=mode,
                                         param=param,
                                         mu=args.mu,
                                         epsilon=args.epsilon,
                                         mus=args.mus)
        elif args.mode == "ndtriv":
            self.rnn = OrthogonalNesterovRNN(n_input,
                                             hidden_size,
                                             initializer_skew=init,
                                             mode=mode,
                                             param=param,
                                             epsilon=args.epsilon)
        else:
            self.rnn = OrthogonalRNN(n_input,
                                     hidden_size,
                                     initializer_skew=init,
                                     mode=mode,
                                     param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.MSELoss()
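
The fg_init=-4.0 argument to the custom LSTM cells presumably initializes the forget-gate bias to -4.0. In a stock nn.LSTMCell the bias vectors stack the four gate biases in the order (input, forget, cell, output), so an equivalent initialization might look like the sketch below; this is an assumption about what fg_init does, not the repository's own implementation.

import torch
import torch.nn as nn

hidden_size = 16
cell = nn.LSTMCell(8, hidden_size)

with torch.no_grad():
    # bias_ih and bias_hh each have shape (4 * hidden_size,) with gate order [i | f | g | o];
    # the forget-gate slice is [hidden_size : 2 * hidden_size].
    cell.bias_ih[hidden_size:2 * hidden_size].fill_(-4.0)
    cell.bias_hh[hidden_size:2 * hidden_size].zero_()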
Example no. 9
class Model(nn.Module):
    def __init__(self, hidden_size, permute):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        self.permute = permute
        permute = np.random.RandomState(args.manualSeed)
        self.register_buffer("permutation",
                             torch.LongTensor(permute.permutation(784)))
        if args.mode == "lstm":
            self.rnn = LSTMCell(1, hidden_size)
        elif args.mode == "mlstm":
            self.rnn = MomentumLSTMCell(1,
                                        hidden_size,
                                        mu=args.mu,
                                        epsilon=args.epsilon)
        elif args.mode == "alstm":
            self.rnn = AdamLSTMCell(1,
                                    hidden_size,
                                    mu=args.mu,
                                    epsilon=args.epsilon,
                                    mus=args.mus)
        elif args.mode == "nlstm":
            self.rnn = NesterovLSTMCell(1, hidden_size, epsilon=args.epsilon)
        elif args.mode == "mdtriv":
            self.rnn = OrthogonalMomentumRNN(1,
                                             hidden_size,
                                             initializer_skew=init,
                                             mode=mode,
                                             param=param,
                                             mu=args.mu,
                                             epsilon=args.epsilon)
        elif args.mode == "adtriv":
            self.rnn = OrthogonalAdamRNN(1,
                                         hidden_size,
                                         initializer_skew=init,
                                         mode=mode,
                                         param=param,
                                         mu=args.mu,
                                         epsilon=args.epsilon,
                                         mus=args.mus)
        elif args.mode == "ndtriv":
            self.rnn = OrthogonalNesterovRNN(1,
                                             hidden_size,
                                             initializer_skew=init,
                                             mode=mode,
                                             param=param,
                                             epsilon=args.epsilon)
        else:
            self.rnn = OrthogonalRNN(1,
                                     hidden_size,
                                     initializer_skew=init,
                                     mode=mode,
                                     param=param)

        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, inputs):
        if self.permute:
            inputs = inputs[:, self.permutation]

        if isinstance(self.rnn, (OrthogonalRNN, OrthogonalMomentumRNN,
                                 OrthogonalAdamRNN, OrthogonalNesterovRNN)):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((inputs.size(0), self.hidden_size),
                                 device=inputs.device),
                     torch.zeros((inputs.size(0), self.hidden_size),
                                 device=inputs.device))

        if isinstance(self.rnn, (MomentumLSTMCell, NesterovLSTMCell)):
            v = torch.zeros((inputs.size(0), 4 * self.hidden_size),
                            device=inputs.device)
        elif isinstance(self.rnn, AdamLSTMCell):
            v = torch.zeros((inputs.size(0), 4 * self.hidden_size),
                            device=inputs.device)
            s = torch.zeros((inputs.size(0), 4 * self.hidden_size),
                            device=inputs.device)
        elif isinstance(self.rnn, OrthogonalAdamRNN):
            v = torch.zeros((inputs.size(0), self.hidden_size),
                            device=inputs.device)
            s = torch.zeros((inputs.size(0), self.hidden_size),
                            device=inputs.device)
        elif isinstance(self.rnn,
                        (OrthogonalMomentumRNN, OrthogonalNesterovRNN)):
            v = torch.zeros((inputs.size(0), self.hidden_size),
                            device=inputs.device)

        iter_indx = 0
        for input in torch.unbind(inputs, dim=1):
            iter_indx = iter_indx + 1
            if isinstance(self.rnn, (MomentumLSTMCell, OrthogonalMomentumRNN)):
                out_rnn, state, v = self.rnn(input.unsqueeze(dim=1), state, v)
            elif isinstance(self.rnn, (AdamLSTMCell, OrthogonalAdamRNN)):
                out_rnn, state, v, s = self.rnn(input.unsqueeze(dim=1), state,
                                                v, s)
            elif isinstance(self.rnn,
                            (NesterovLSTMCell, OrthogonalNesterovRNN)):
                out_rnn, state, v = self.rnn(input.unsqueeze(dim=1),
                                             state,
                                             v,
                                             k=iter_indx)
                if args.restart > 0 and not (iter_indx % args.restart):
                    iter_indx = 0
            else:
                out_rnn, state = self.rnn(input.unsqueeze(dim=1), state)

        return self.lin(out_rnn)

    def loss(self, logits, y):
        return self.loss_func(logits, y)

    def correct(self, logits, y):
        return torch.eq(torch.argmax(logits, dim=1), y).float().sum()
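
The momentum, Adam, and Nesterov cells carry optimizer-style state through the sequence: a velocity v (sized 4 * hidden_size for the LSTM variants, which work in gate space, and hidden_size for the orthogonal ones), a second buffer s for the Adam variants, and a step counter k that the restart logic resets. The toy stand-in below only illustrates this bookkeeping interface; the arithmetic inside is a placeholder, not the MomentumRNN update rule.

import torch
import torch.nn as nn

class ToyMomentumCell(nn.Module):
    """Interface-only stand-in: forward(x, state, v) -> (h, state, v)."""
    def __init__(self, input_size, hidden_size, mu=0.9, epsilon=0.1):
        super().__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.mu, self.epsilon = mu, epsilon

    def forward(self, x, state, v):
        h, c = self.cell(x, state)
        v = self.mu * v + self.epsilon * h     # placeholder update, not the paper's formula
        return h, (h, c), v

batch, hidden = 4, 16
cell = ToyMomentumCell(1, hidden)
state = (torch.zeros(batch, hidden), torch.zeros(batch, hidden))
v = torch.zeros(batch, hidden)
inputs = torch.randn(batch, 28, 1)             # e.g. one pixel per time step
for step in torch.unbind(inputs, dim=1):       # thread (state, v) through the sequence
    h, state, v = cell(step, state, v)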
Example no. 10
class Model(nn.Module):
    def __init__(self, hidden_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        if args.mode == "lstm":
            self.rnn = LSTMCell(n_input, hidden_size, fg_init=-4.0)
        elif args.mode == "mlstm":
            self.rnn = MomentumLSTMCell(n_input,
                                        hidden_size,
                                        mu=args.mu,
                                        epsilon=args.epsilon,
                                        fg_init=-4.0)
        elif args.mode == "alstm":
            self.rnn = AdamLSTMCell(n_input,
                                    hidden_size,
                                    mu=args.mu,
                                    epsilon=args.epsilon,
                                    mus=args.mus,
                                    fg_init=-4.0)
        elif args.mode == "nlstm":
            self.rnn = NesterovLSTMCell(n_input,
                                        hidden_size,
                                        epsilon=args.epsilon,
                                        fg_init=-4.0)
        elif args.mode == "mdtriv":
            self.rnn = OrthogonalMomentumRNN(n_input,
                                             hidden_size,
                                             initializer_skew=init,
                                             mode=mode,
                                             param=param,
                                             mu=args.mu,
                                             epsilon=args.epsilon)
        else:
            self.rnn = OrthogonalRNN(n_input,
                                     hidden_size,
                                     initializer_skew=init,
                                     mode=mode,
                                     param=param)
        self.lin = nn.Linear(hidden_size, n_classes)
        self.loss_func = nn.MSELoss()

    def forward(self, inputs):
        if isinstance(self.rnn, (OrthogonalRNN, OrthogonalMomentumRNN)):
            state = self.rnn.default_hidden(inputs[:, 0, ...])
        else:
            state = (torch.zeros((inputs.size(0), self.hidden_size),
                                 device=inputs.device),
                     torch.zeros((inputs.size(0), self.hidden_size),
                                 device=inputs.device))

        if isinstance(self.rnn, (MomentumLSTMCell, NesterovLSTMCell)):
            v = torch.zeros((inputs.size(0), 4 * self.hidden_size),
                            device=inputs.device)
        elif isinstance(self.rnn, AdamLSTMCell):
            v = torch.zeros((inputs.size(0), 4 * self.hidden_size),
                            device=inputs.device)
            s = torch.zeros((inputs.size(0), 4 * self.hidden_size),
                            device=inputs.device)
        elif isinstance(self.rnn, OrthogonalMomentumRNN):
            v = torch.zeros((inputs.size(0), self.hidden_size),
                            device=inputs.device)

        outputs = []
        iter_indx = 0
        for input in torch.unbind(inputs, dim=1):
            iter_indx = iter_indx + 1
            if isinstance(self.rnn, (MomentumLSTMCell, OrthogonalMomentumRNN)):
                out_rnn, state, v = self.rnn(input, state, v)
            elif isinstance(self.rnn, AdamLSTMCell):
                out_rnn, state, v, s = self.rnn(input, state, v, s)
            elif isinstance(self.rnn, NesterovLSTMCell):
                out_rnn, state, v = self.rnn(input, state, v, k=iter_indx)
                if args.restart > 0 and not (iter_indx % args.restart):
                    iter_indx = 0
            else:
                out_rnn, state = self.rnn(input, state)

            outputs.append(self.lin(out_rnn))
        return torch.stack(outputs, dim=1)

    def loss(self, logits, y, len_batch):
        return masked_loss(self.loss_func, logits, y, len_batch)
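
masked_loss is defined elsewhere in the script; presumably it restricts the criterion to the valid prefix of each variable-length sequence given the per-example lengths in len_batch. The helper below is a hypothetical sketch of that behavior, not the repository's definition.

import torch

def masked_mse(logits, y, len_batch):
    """Hypothetical masked MSE: average the squared error over valid time steps only."""
    # logits, y: (batch, seq_len, n_out); len_batch: (batch,) valid lengths per sequence
    seq_len = logits.size(1)
    mask = (torch.arange(seq_len, device=logits.device)[None, :]
            < len_batch[:, None]).unsqueeze(-1).float()        # (batch, seq_len, 1)
    sq_err = (logits - y) ** 2
    return (sq_err * mask).sum() / (mask.sum() * logits.size(-1))

logits = torch.randn(3, 5, 2)
y = torch.randn(3, 5, 2)
len_batch = torch.tensor([5, 3, 4])
loss = masked_mse(logits, y, len_batch)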