Example #1
0
    def __init__(self, input_dim, hidden_dim, diag_dim, nonlinearity='tanh'):
        """Build the MLP that parameterizes a PSD matrix factor.

        Args:
            input_dim: size of the network input.
            hidden_dim: width of the hidden layers.
            diag_dim: number of diagonal entries. For ``diag_dim == 1`` the
                network has a single scalar output head; for ``diag_dim > 1``
                it also emits the ``diag_dim * (diag_dim - 1) / 2``
                off-diagonal entries of a triangular factor.
            nonlinearity: activation name resolved by ``choose_nonlinearity``.
        """
        super(PSD, self).__init__()
        # Original code only rejected diag_dim <= 0 inside the else branch;
        # checking once up front is equivalent and clearer.
        assert diag_dim >= 1
        self.diag_dim = diag_dim
        # The first two layers are identical in both configurations, so build
        # them once instead of duplicating them per branch.
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        if diag_dim == 1:
            # Scalar case: a single output head.
            self.linear3 = torch.nn.Linear(hidden_dim, diag_dim)
            layers = [self.linear1, self.linear2, self.linear3]
        else:
            # Matrix case: one more hidden layer, then a head that emits the
            # diagonal plus the strictly-triangular entries.
            # // keeps the count in integer arithmetic (n*(n-1) is even).
            self.off_diag_dim = diag_dim * (diag_dim - 1) // 2
            self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
            self.linear4 = torch.nn.Linear(hidden_dim,
                                           self.diag_dim + self.off_diag_dim)
            layers = [self.linear1, self.linear2, self.linear3, self.linear4]

        for l in layers:
            torch.nn.init.orthogonal_(l.weight)  # use a principled initialization

        self.nonlinearity = choose_nonlinearity(nonlinearity)
Example #2
0
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 latent_dim,
                 nonlinearity='tanh',
                 dropout_rate=0.1):
        """Symmetric MLP autoencoder: four encoder and four decoder layers.

        Layers ``linear1``-``linear4`` map input -> latent; ``linear5``-
        ``linear8`` map latent -> input. Dropouts ``dropout1``-``dropout6``
        share a single rate. Weights get an orthogonal initialization.
        """
        super(MLPAutoencoder, self).__init__()
        # (in, out) shapes for linear1..linear8, encoder first, then decoder.
        shapes = [(input_dim, hidden_dim), (hidden_dim, hidden_dim),
                  (hidden_dim, hidden_dim), (hidden_dim, latent_dim),
                  (latent_dim, hidden_dim), (hidden_dim, hidden_dim),
                  (hidden_dim, hidden_dim), (hidden_dim, input_dim)]
        # setattr on a Module registers submodules exactly like a direct
        # attribute assignment; creation order matches the original so the
        # RNG is consumed identically.
        for idx, (n_in, n_out) in enumerate(shapes, start=1):
            setattr(self, 'linear%d' % idx, torch.nn.Linear(n_in, n_out))

        for idx in range(1, 7):
            setattr(self, 'dropout%d' % idx, torch.nn.Dropout(p=dropout_rate))

        # use a principled initialization
        for idx in range(1, 9):
            torch.nn.init.orthogonal_(getattr(self, 'linear%d' % idx).weight)

        self.nonlinearity = choose_nonlinearity(nonlinearity)
Example #3
0
    def __init__(self, input_dim, hidden_dim, output_dim, nonlinearity='tanh'):
        """Three-layer MLP with a bias-free output layer.

        Args:
            input_dim: size of the network input.
            hidden_dim: width of the two hidden layers.
            output_dim: size of the network output.
            nonlinearity: activation name resolved by ``choose_nonlinearity``.
        """
        super(MLP, self).__init__()
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        # bias=False, not bias=None: the argument is documented as a bool and
        # None only disabled the bias because it happens to be falsy.
        self.linear3 = torch.nn.Linear(hidden_dim, output_dim, bias=False)

        for l in [self.linear1, self.linear2, self.linear3]:
            torch.nn.init.orthogonal_(
                l.weight)  # use a principled initialization

        self.nonlinearity = choose_nonlinearity(nonlinearity)
Example #4
0
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 diag_dim,
                 device,
                 nonlinearity='tanh'):
        """MLP producing a scalar damping coefficient from half the state.

        Args:
            input_dim: full state dimension; the network only consumes
                ``input_dim // 2`` features (presumably the momentum/velocity
                half of the state — TODO confirm against the caller).
            hidden_dim: width of the hidden layers.
            diag_dim: must be > 1; validated but otherwise unused here.
            device: stored on the instance for later use.
            nonlinearity: activation name resolved by ``choose_nonlinearity``.
        """
        super(DampMatrix, self).__init__()
        assert diag_dim > 1
        # // 2 keeps the halving in integer arithmetic instead of
        # round-tripping through float division (equivalent for the
        # positive dimensions used here).
        self.linear1 = torch.nn.Linear(input_dim // 2, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        # bias=False, not bias=None: the argument is documented as a bool and
        # None only disabled the bias because it happens to be falsy.
        self.linear3 = torch.nn.Linear(hidden_dim, 1, bias=False)

        for l in [self.linear1, self.linear2, self.linear3]:
            torch.nn.init.orthogonal_(
                l.weight)  # use a principled initialization

        self.nonlinearity = choose_nonlinearity(nonlinearity)
        self.device = device