Example #1
class TiedLinear(torch.nn.Module):
    def __init__(self, in_features, output_dim, bias=False):
        super(TiedLinear, self).__init__()
        self.in_features = in_features
        self.output_dim = output_dim
        self.weight = Parameter(torch.Tensor(1, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(1, in_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.W = self.weight.expand(output_dim, -1)
        if self.bias is not None:
            self.B = self.bias.expand(output_dim, -1)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, X, index, mask):
        output = X.mul(self.W)
        if self.bias is not None:
            output += self.B
        output = output.sum(2)
        output.index_add_(0, index, mask)
        return output
Example #2
class BiasedEmbedding(nn.Module):
    def __init__(self, n_feat, n_dim, lv=1.0, lb=1.0):
        super(BiasedEmbedding, self).__init__()
        self.vect = nn.Embedding(n_feat, n_dim)
        self.bias = nn.Embedding(n_feat, 1)
        self.off_vect = Parameter(torch.zeros(1, n_dim))
        self.mul_vect = Parameter(torch.ones(1, n_dim))
        self.off_bias = Parameter(torch.zeros(1))
        self.mul_bias = Parameter(torch.ones(1))
        self.n_dim = n_dim
        self.n_feat = n_feat
        self.lv = lv
        self.lb = lb
        self.vect.weight.data.normal_(0, 1.0 / n_dim)
        self.bias.weight.data.normal_(0, 1.0 / n_dim)

    def __call__(self, index):
        assert (index.max() < self.n_feat).all()
        assert (index.min() >= 0).all()
        off_vect = self.off_vect.expand(len(index), self.n_dim).squeeze()
        off_bias = self.off_bias.expand(len(index), 1).squeeze()
        bias = off_bias + self.mul_bias * self.bias(index).squeeze()
        vect = off_vect + self.mul_vect * self.vect(index)
        return bias, vect

    def prior(self):
        loss = ((self.vect.weight**2.0).sum() * self.lv +
                (self.bias.weight**2.0).sum() * self.lb)
        return loss
Example #3
class RNN(nn.Module):
    def __init__(self, input_size=9, hidden_size=100, output_size=9):
        super(RNN, self).__init__()
        MAX = 4000
        EOS = torch.from_numpy(np.array(8 * [0] + [1])).float()
        self.hidden_size = hidden_size
        self.LSTM = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        self.activation = functional.sigmoid
        self.hidden_state0 = Parameter(torch.zeros(1, hidden_size)).float()
        self.cell_state0 = Parameter(torch.zeros(1, hidden_size)).float()
        self.zero_vector = Parameter(torch.zeros(MAX, 9)).float()
        # self.zero_vector = Parameter(EOS.expand(MAX, 9))

    def step(self, input_vector, hidden_state, cell_state):
        hidden_state, cell_state = self.LSTM(input_vector, (hidden_state, cell_state))
        return hidden_state, cell_state, self.fc(hidden_state)

    def forward(self, input_vectors):
        N = input_vectors.shape[0]
        T = input_vectors.shape[1] - 1

        hidden_state = self.hidden_state0.expand(N, self.hidden_size)
        cell_state = self.cell_state0.expand(N, self.hidden_size)

        for t in range(T + 1):
            hidden_state, cell_state, _ = self.step(input_vectors[:, t, :], hidden_state, cell_state)

        outputs = []
        for t in range(T):
            hidden_state, cell_state, output = self.step(self.zero_vector[:N, :], hidden_state, cell_state)
            outputs.append(self.activation(output.unsqueeze(2).transpose(1, 2)))
        return torch.cat(outputs, 1)
Example #4
class BasisNet(torch.nn.Module):
    def __init__(self, include_G_weights):
        super(BasisNet, self).__init__()

        self.include_G_weights = include_G_weights

        self.psd_vec = torch.nn.Linear(2, 3, bias=False)
        self.psd_vec_bias = Parameter(torch.ones(3))
        self.nonneg_matrix = torch.nn.Linear(2, 9, bias=False)
        self.nonneg_matrix_bias = Parameter(torch.ones(9))
        self.antisym = torch.nn.Linear(2, 3, bias=False)
        self.antisym_bias = Parameter(torch.ones(3))

        self.antisym_basis = torch.tensor([[[0, 1, 0], [-1, 0, 0], [0, 0, 0]],
                                           [[0, 0, 1], [0, 0, 0], [-1, 0, 0]],
                                           [[0, 0, 0], [0, 0, 1], [0, -1, 0]]]).float()

        self.f = torch.nn.Linear(2, 3, bias=True)

        if not include_G_weights:
            self.psd_vec.weight.data.fill_(0)
            self.nonneg_matrix.weight.data.fill_(0)
            self.antisym.weight.data.fill_(0)

    def forward(self, states):
        lambdas = states[:, 2:5]
        Gxu, fxu = self.get_lcps(states)

        lcp_slack = fxu + torch.bmm(Gxu, lambdas.unsqueeze(2)).squeeze(2)

        return lcp_slack

    def get_lcps(self, states):
        lambdas = states[:, 2:5]
        xus = states[:, 0:2]

        fxu = self.f(xus)

        if self.include_G_weights:
            psd_vec = self.psd_vec(xus) + self.psd_vec_bias
            nonneg_matrix = self.nonneg_matrix(xus) + self.nonneg_matrix_bias
            antisym_vec = self.antisym(xus) + self.antisym_bias
        else:
            psd_vec = self.psd_vec_bias.expand(states.shape[0], 3)
            nonneg_matrix = self.nonneg_matrix_bias.expand(states.shape[0], 9)
            antisym_vec = self.antisym_bias.expand(states.shape[0], 3)

        psd_term = torch.bmm(psd_vec.unsqueeze(2), psd_vec.unsqueeze(1))

        nonneg_term = F.relu(nonneg_matrix.view(-1, 3, 3))

        antisym_vec = antisym_vec.unsqueeze(2).unsqueeze(3)
        antisym_basis = self.antisym_basis.expand(antisym_vec.shape[0], 3, 3, 3)
        antisym_term = torch.sum(antisym_vec * antisym_basis, dim=1)

        Gxu = psd_term + nonneg_term + antisym_term

        return Gxu, fxu
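The antisymmetric term above contracts a length-3 coefficient vector with a fixed basis of antisymmetric 3x3 matrices. A standalone sketch (independent of the class; the batch size is chosen arbitrarily) checking that the contraction indeed yields antisymmetric matrices:

import torch

basis = torch.tensor([[[0, 1, 0], [-1, 0, 0], [0, 0, 0]],
                      [[0, 0, 1], [0, 0, 0], [-1, 0, 0]],
                      [[0, 0, 0], [0, 0, 1], [0, -1, 0]]]).float()
vec = torch.randn(2, 3)  # batch of 2 coefficient vectors
# contract coefficients against the basis, as in get_lcps above
A = torch.sum(vec.unsqueeze(2).unsqueeze(3) * basis.expand(2, 3, 3, 3), dim=1)
print(torch.allclose(A, -A.transpose(1, 2)))  # True: each A is antisymmetric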
Example #5
class AttentionModule(nn.Module):
    def __init__(self, config):
        super(AttentionModule, self).__init__()
        self.config = config
        self.rnn = nn.LSTM(config.word_emb_dim + config.resnet_features,
        #self.rnn = nn.LSTM(config.word_emb_dim,
                           config.lstm_hidden_size, config.recurrent_layers,
                           batch_first=False, bidirectional=False, dropout=config.recurrent_dropout)
        self.weight_feature2attn = Parameter(torch.Tensor(config.resnet_features))
        self.weight_hidden2attn = Parameter(torch.Tensor(config.lstm_hidden_size, config.resnet_fmap_size))
        self.bias_attention = Parameter(torch.Tensor(config.resnet_fmap_size))
        self.mlp = nn.Sequential(
            nn.Linear(config.lstm_hidden_size, config.mlp_hidden_size),
            #nn.Linear(config.lstm_hidden_size+config.resnet_features, config.mlp_hidden_size),
            nn.ReLU(),
            nn.Dropout(config.mlp_dropout),
            nn.Linear(config.mlp_hidden_size, 1),
            nn.Sigmoid()
        )
        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.parameters():
            stdv = 2.0 / math.sqrt(weight.size(-1))
            weight.data.uniform_(-stdv, stdv)

    def forward(self, rawwords_emb, image_features):
        words_emb, lengths = pad_packed_sequence(rawwords_emb, batch_first=False)
        seq_len, batch_size = words_emb.size()[0:2]
        hidden_state = Variable(torch.zeros(self.config.recurrent_layers, batch_size, self.config.lstm_hidden_size))
        cell_state = Variable(torch.zeros(self.config.recurrent_layers, batch_size, self.config.lstm_hidden_size))
        if self.config.cuda:
            hidden_state, cell_state = hidden_state.cuda(), cell_state.cuda()
        weight = self.weight_feature2attn.expand(batch_size, self.config.resnet_features).unsqueeze(2)
        feature2a = torch.bmm(image_features.transpose(1, 2), weight).squeeze(2)
        outputs = []
        for i in range(seq_len):
            # feature2a, attn_weight are both (batch_size, 512)
            attn_weight = F.softmax(feature2a + torch.mm(hidden_state[-1], self.weight_hidden2attn) +
                    self.bias_attention.expand(batch_size, self.config.resnet_fmap_size))
            attn_out = torch.bmm(image_features, attn_weight.unsqueeze(2)).squeeze(2) # (bz, 512)
            inputs = torch.cat([words_emb[i], attn_out], 1).unsqueeze(0)
            output, (hidden_state, cell_state) = self.rnn(inputs, (hidden_state, cell_state))
            outputs.append(output.squeeze(0))
        outputs = torch.cat([outputs[lengths[i]-1][i].unsqueeze(0) for i in range(len(lengths))], 0)
        scores = self.mlp(outputs).squeeze(1)
        '''
        outputs, (hidden_state, cell_state) = self.rnn(rawwords_emb)
        inputs = torch.cat([hidden_state[-1], image_features.mean(2).squeeze(2)], 1)
        scores = self.mlp(inputs).squeeze(1)
        '''
        return scores

    def eval(self):
        self.rnn.eval()
        self.mlp.eval()

    def train(self):
        self.rnn.train()
        self.mlp.train()
Example #6
    def initial_state_means_for_batch(self, parameters: Parameter,
                                      num_groups: int, **kwargs) -> Tensor:
        """
        Most children should use the default. Handles rearranging of state means
        based on ``for_batch`` keyword args; e.g. a discrete seasonal process with
        a state element for each season would need to know on which season the
        batch starts.
        """
        return parameters.expand(num_groups, -1)
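A minimal sketch of the broadcasting this method relies on; the names and sizes below are illustrative, not from the original code base:

import torch
from torch.nn import Parameter

state_size, num_groups = 4, 3
initial_mean = Parameter(torch.zeros(1, state_size))   # shared across groups
batch_means = initial_mean.expand(num_groups, -1)      # a view, no parameter copy
print(batch_means.shape)  # torch.Size([3, 4])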
Example #7
class Affine(Transform):
    def __init__(self, loc=0.0, scale=1.0, learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(1, -1)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(1, -1)
        self.loc = loc.float()
        self.scale = scale.float()
        self.n_dims = len(loc)
        if learnable:
            self.loc = Parameter(self.loc)
            self.scale = Parameter(self.scale)

    def forward(self, x):
        return self.loc + self.scale * x

    def inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        return torch.log(torch.abs(self.scale.expand(x.size()))).sum(-1)

    def get_parameters(self):
        return {
            'type': 'affine',
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
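A quick round-trip sketch, assuming the Affine class above (and its Transform base) is in scope; the values are arbitrary:

import torch

affine = Affine(loc=2.0, scale=3.0, learnable=False)
x = torch.randn(5, 1)
y = affine.forward(x)                         # loc + scale * x
print(torch.allclose(affine.inverse(y), x))   # True: inverse undoes forward
print(affine.log_abs_det_jacobian(x, y))      # log|scale| summed over the last dim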
Example #8
class SqueezedAttFeatTrans(nn.Module):
    def __init__(self, config, name):
        super(SqueezedAttFeatTrans, self).__init__()
        self.config = config
        self.name = name
        self.in_feat_dim = config.in_feat_dim
        self.num_attractors = config.num_attractors

        # Disable feature squeezing in in_ator_trans.
        config1 = copy.copy(config)
        config1.feat_dim = config1.in_feat_dim
        config1.num_modes = 1
        self.in_ator_trans = CrossAttFeatTrans(config1, name + '-in-squeeze')
        self.ator_out_trans = CrossAttFeatTrans(config, name + '-squeeze-out')
        self.attractors = Parameter(
            torch.randn(1, self.num_attractors, self.in_feat_dim))

    def forward(self, in_feat, attention_mask=None):
        # in_feat: [B, 196, 1792]
        batch_size = in_feat.shape[0]
        batch_attractors = self.attractors.expand(batch_size, -1, -1)
        new_batch_attractors = self.in_ator_trans(batch_attractors, in_feat)
        out_feat = self.ator_out_trans(in_feat, new_batch_attractors)
        self.attention_scores = self.ator_out_trans.attention_scores
        return out_feat
Example #9
class ReparamNormal_Mu_Logvar(ReparamNormal):
    def __init__(self, output_dim=2, **kwargs):
        super().__init__()
        mu = kwargs.get('mu', torch.zeros(output_dim))
        self.mu_param = Parameter(mu.unsqueeze(0))
        logvar = kwargs.get('logvar', torch.zeros(output_dim))
        self.logvar_param = Parameter(logvar.unsqueeze(0))

    def forward(self, x):
        s = x.size(0)
        o = self.mu_param.size(1)
        return self.mu_param.expand(s, o), self.logvar_param.expand(s, o)

    def __repr__(self):
        return super().__repr__() + \
            '(_ -> {})'.format(self.mu_param.size(0))
Example #10
class ConstantDiscountRateLayer(DiscountLayer):
    def __init__(self, forward_rate, time_len=MAX_YR_LEN):
        Module.__init__(self)
        self.time_len = time_len
        if isinstance(forward_rate, float):
            forward_rate = torch.tensor([forward_rate])
        self._forward_rate = Parameter(forward_rate)
        self.forward_rate = self._forward_rate.expand(time_len)
Example #11
class Rnn(Module):
    """ A BiLSTM or BiGRU """
    def __init__(self, input_dim, hidden_dim, cell_type=LSTM, gpu=False):
        super(Rnn, self).__init__()
        self.hidden_dim = hidden_dim
        self.to_cuda = to_cuda(gpu)
        self.input_dim = input_dim
        self.cell_type = cell_type
        self.rnn = self.cell_type(input_size=self.input_dim,
                                  hidden_size=hidden_dim,
                                  num_layers=1,
                                  bidirectional=True)

        self.num_directions = 2  # We're a *bi*LSTM
        self.start_hidden_state = \
            Parameter(self.to_cuda(
                torch.randn(self.num_directions, 1, self.hidden_dim)
            ))
        self.start_cell_state = \
            Parameter(self.to_cuda(
                torch.randn(self.num_directions, 1, self.hidden_dim)
            ))

    def forward(self, batch, debug=0, dropout=None):
        """
        Run a biLSTM over the batch of docs, return their hidden states (padded).
        """
        b = len(batch.docs)
        docs_vectors = [
            torch.index_select(batch.embeddings_matrix, 1, doc).t()
            for doc in batch.docs
        ]
        # Assumes/requires that `batch.docs` is sorted by decreasing doc length.
        # This gets done in `chunked_sorted`.
        packed = pack_padded_sequence(torch.stack(docs_vectors, dim=1),
                                      lengths=list(batch.doc_lens))

        # run the biLSTM
        starts = (self.start_hidden_state.expand(self.num_directions, b,
                                                 self.hidden_dim).contiguous(),
                  self.start_cell_state.expand(self.num_directions, b,
                                               self.hidden_dim).contiguous())
        outs, _ = self.rnn(packed, starts)

        return outs
Example #12
class Attention(nn.Module):
    """pointer network attention mechanism"""
    def __init__(self, input_dim, hidden_dim):
        super(Attention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.d_linear = nn.Linear(input_dim, hidden_dim)
        self.e_linear = nn.Linear(input_dim, hidden_dim)
        self.V = Parameter(torch.FloatTensor(hidden_dim), requires_grad=True)

        self.inf = Parameter(
            torch.FloatTensor([float("-inf")]),
            requires_grad=False)  # avoid grad calculation at certain steps

    def init_inf(self, size):
        self.inf = self.inf.expand(*size)

    def forward(self, hidden, context, mask):
        """
        :type hidden: torch.Tensor, [batch_size, hidden_dim]
        :type context: torch.Tensor, [batch_size, seq_len, hidden_dim]
        :type mask: torch.Tensor, [batch_size, seq_len]
        """

        batch_size = context.shape[0]
        seq_len = context.shape[1]
        di = self.d_linear(hidden).unsqueeze(1).expand(-1, seq_len, -1)  # [batch, seq_len, hidden_dim]

        ei = self.e_linear(context)  # [batch, seq_len, hidden_dim]

        ui = torch.bmm(self.V.expand(batch_size, 1, -1),
                       F.tanh(di + ei).permute(0, 2, 1)).squeeze(
                           1)  # [batch, seq_len]

        # mask already-pointed positions before the softmax
        ui[mask] = self.inf[mask]

        alpha = F.softmax(ui, dim=1)  # [batch, seq_len]

        attention_hidden_state = torch.bmm(ei.permute(0, 2, 1),
                                           alpha.unsqueeze(2)).squeeze(
                                               2)  # [batch, hidden_dim]

        return alpha, attention_hidden_state
Example #13
class LSTMCell(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 bias=True,
                 layernorm=False,
                 dropoutr=0):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight = Parameter(
            torch.Tensor(input_size + hidden_size, 3 * hidden_size))

        self.bias = Parameter(torch.Tensor(3 * hidden_size)) if bias else None
        self.layernorm = nn.ModuleList(
            [nn.LayerNorm(hidden_size)
             for _ in range(3)]) if layernorm else None
        self.dropoutr = nn.Dropout(dropoutr) if dropoutr > 0 else None

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
        if self.layernorm:
            for ln in self.layernorm:
                ln.reset_parameters()

    def forward(self, input, hx=None):
        # input: (batch, input_size)
        # hx: tuple of (batch, hidden_size), (batch, hidden_size)
        if hx is None:
            hx = input.new_zeros(input.size(0),
                                 self.hidden_size,
                                 requires_grad=False)
            hx = (hx, hx)
        h_0, c_0 = hx

        pre_gate = torch.mm(torch.cat((input, h_0), 1), self.weight)
        if self.bias is not None:
            pre_gate = pre_gate + self.bias  # bias broadcasts over the batch dimension
        f, o, g = pre_gate.split(self.hidden_size, 1)
        if self.layernorm:
            f = self.layernorm[0](f)
            o = self.layernorm[1](o)
            g = self.layernorm[2](g)
        f = torch.sigmoid(f + 1.)  # _forget_bias
        i = 1. - f  # input and forget gates are coupled
        o = torch.sigmoid(o)
        g = torch.tanh(g)

        if self.dropoutr:
            g = self.dropoutr(g)  # recurrent dropout without memory loss

        c_1 = f * c_0 + i * g
        h_1 = o * torch.tanh(c_1)
        return h_1, c_1
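A small smoke test for the cell above, assuming the class and its imports (torch, math, nn, Parameter) are in scope; batch size and dimensions are arbitrary:

import torch

cell = LSTMCell(input_size=8, hidden_size=16, layernorm=True)
x = torch.randn(4, 8)                       # (batch, input_size)
h1, c1 = cell(x)                            # zero initial state is created internally
h2, c2 = cell(torch.randn(4, 8), (h1, c1))  # feed the state back in
print(h2.shape, c2.shape)                   # torch.Size([4, 16]) twice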
Example #14
class TiedLinear(torch.nn.Module):
    """
    TiedLinear is a linear layer with shared parameters for features between
    (output) classes that takes as input a tensor X with dimensions
        (batch size) X (output_dim) X (in_features)
        where:
            output_dim is the desired output dimension/# of classes
            in_features are the features with shared weights across the classes
    """

    def __init__(self, in_features, output_dim, bias=False):
        super(TiedLinear, self).__init__()
        self.in_features = in_features
        self.output_dim = output_dim
        self.weight = Parameter(torch.Tensor(1,in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(1,in_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        # Broadcast parameters to the correct matrix dimensions for matrix
        # multiplication: this does NOT create new parameters: i.e. each
        # row of in_features of parameters are connected and will adjust
        # to the same values.
        self.W = self.weight.expand(output_dim, -1)
        if self.bias is not None:
            self.B = self.bias.expand(output_dim, -1)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, X, index, mask):
        output = X.mul(self.W)
        if self.bias is not None:
            output += self.B
        output = output.sum(2)
        # Add our mask so that invalid domain classes for a given variable/VID
        # has a large negative value, resulting in a softmax probability
        # of de facto 0.
        output.index_add_(0, index, mask)
        return output
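A small sketch of the shapes described in the docstring, assuming the class above and its imports (torch, math, Parameter) are in scope; the index and mask values here are only illustrative:

import torch

batch, output_dim, in_features = 2, 3, 4
layer = TiedLinear(in_features, output_dim)
X = torch.randn(batch, output_dim, in_features)
index = torch.tensor([0, 1])             # which output row each mask row is added to
mask = torch.zeros(batch, output_dim)    # e.g. large negative values at invalid classes
out = layer(X, index, mask)              # (batch, output_dim), one score per class
print(out.shape)  # torch.Size([2, 3])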
Example #15
class Isotropy(Kernel):
    """
    Base class for a family of isotropic covariance kernels which are functions of the
    distance :math:`|x-z|/l`, where :math:`l` is the length-scale parameter.

    By default, the parameter ``lengthscale`` has size 1. To use the anisotropic version
    (a different lengthscale for each dimension), make sure that ``lengthscale`` has size
    equal to ``input_dim``.

    :param torch.Tensor lengthscale: Length-scale parameter of this kernel.
    """
    def __init__(self,
                 input_dim,
                 variance=None,
                 lengthscale=None,
                 active_dims=None):
        super(Isotropy, self).__init__(input_dim, active_dims)

        variance = torch.tensor(1.) if variance is None else variance
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

        lengthscale = torch.tensor(1.) if lengthscale is None else lengthscale
        self.lengthscale = Parameter(lengthscale)
        self.set_constraint("lengthscale", constraints.positive)

    def _square_scaled_dist(self, X, Z=None):
        r"""
        Returns :math:`\|\frac{X-Z}{l}\|^2`.
        """
        if Z is None:
            Z = X
        X = self._slice_input(X)
        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        scaled_X = X / self.lengthscale
        scaled_Z = Z / self.lengthscale
        X2 = (scaled_X**2).sum(1, keepdim=True)
        Z2 = (scaled_Z**2).sum(1, keepdim=True)
        XZ = scaled_X.matmul(scaled_Z.t())
        r2 = X2 - 2 * XZ + Z2.t()
        return r2.clamp(min=0)

    def _scaled_dist(self, X, Z=None):
        r"""
        Returns :math:`\|\frac{X-Z}{l}\|`.
        """
        return _torch_sqrt(self._square_scaled_dist(X, Z))

    def _diag(self, X):
        """
        Calculates the diagonal part of covariance matrix on active features.
        """
        return self.variance.expand(X.size(0))
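``_square_scaled_dist`` above relies on the expansion ||x - z||^2 = ||x||^2 - 2 x.z + ||z||^2 applied to the lengthscale-scaled inputs. A standalone check of that identity (no pyro Kernel machinery needed):

import torch

X, Z, l = torch.randn(5, 3), torch.randn(7, 3), torch.tensor(2.0)
sX, sZ = X / l, Z / l
r2_expanded = ((sX**2).sum(1, keepdim=True) - 2 * sX.matmul(sZ.t())
               + (sZ**2).sum(1, keepdim=True).t())
r2_direct = ((sX.unsqueeze(1) - sZ.unsqueeze(0))**2).sum(-1)
print(torch.allclose(r2_expanded, r2_direct, atol=1e-5))  # True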
Example #16
class ConstantMean(Mean):

    def __init__(self, constant=None):
        super(ConstantMean, self).__init__()

        constant = torch.zeros(1) if constant is None else constant
        self.constant = Parameter(constant)

    def forward(self, X):
        return self.constant.expand(X.size(0))
Example #17
class Constant(Kernel):
    r"""
    Implementation of Constant kernel:

        :math:`k(x, z) = \sigma^2.`
    """
    def __init__(self, input_dim, variance=None, active_dims=None):
        super(Constant, self).__init__(input_dim, active_dims)

        variance = torch.tensor(1.) if variance is None else variance
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self.variance.expand(X.size(0))

        if Z is None:
            Z = X
        return self.variance.expand(X.size(0), Z.size(0))
Example #18
class WhiteNoise(Kernel):
    r"""
    Implementation of WhiteNoise kernel:

        :math:`k(x, z) = \sigma^2 \delta(x, z),`

    where :math:`\delta` is a Dirac delta function.
    """
    def __init__(self, input_dim, variance=None, active_dims=None):
        super(WhiteNoise, self).__init__(input_dim, active_dims)

        variance = torch.tensor(1.) if variance is None else variance
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self.variance.expand(X.size(0))

        if Z is None:
            return self.variance.expand(X.size(0)).diag()
        else:
            return X.data.new_zeros(X.size(0), Z.size(0))
Example #19
class Periodic(Kernel):
    r"""
    Implementation of Periodic kernel:

        :math:`k(x,z)=\sigma^2\exp\left(-2\times\frac{\sin^2(\pi(x-z)/p)}{l^2}\right),`

    where :math:`p` is the ``period`` parameter.

    References:

    [1] `Introduction to Gaussian processes`,
    David J.C. MacKay

    :param torch.Tensor lengthscale: Length scale parameter of this kernel.
    :param torch.Tensor period: Period parameter of this kernel.
    """
    def __init__(self,
                 input_dim,
                 variance=None,
                 lengthscale=None,
                 period=None,
                 active_dims=None):
        super(Periodic, self).__init__(input_dim, active_dims)

        variance = torch.tensor(1.) if variance is None else variance
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

        lengthscale = torch.tensor(1.) if lengthscale is None else lengthscale
        self.lengthscale = Parameter(lengthscale)
        self.set_constraint("lengthscale", constraints.positive)

        period = torch.tensor(1.) if period is None else period
        self.period = Parameter(period)
        self.set_constraint("period", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self.variance.expand(X.size(0))

        if Z is None:
            Z = X
        X = self._slice_input(X)
        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        d = X.unsqueeze(1) - Z.unsqueeze(0)
        scaled_sin = torch.sin(math.pi * d / self.period) / self.lengthscale
        return self.variance * torch.exp(-2 * (scaled_sin**2).sum(-1))
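A standalone check of the formula above (no pyro dependency): at x = z the sine term vanishes, so k(x, x) reduces to the variance, which is exactly what the diag branch returns.

import math
import torch

variance, lengthscale, period = torch.tensor(1.5), torch.tensor(0.7), torch.tensor(2.0)
x = torch.randn(4, 1)
d = x.unsqueeze(1) - x.unsqueeze(0)
scaled_sin = torch.sin(math.pi * d / period) / lengthscale
K = variance * torch.exp(-2 * (scaled_sin**2).sum(-1))
print(torch.allclose(K.diagonal(), variance.expand(4)))  # True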
Example #20
class PointerNet(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, layers, dropout):
        super(PointerNet, self).__init__()

        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.layers = layers
        self.dropout = dropout

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = Encoder(embed_dim, hidden_dim, layers, dropout)

        self.decoder = Decoder(embed_dim, hidden_dim)

        self.first_decoder_input = Parameter(torch.FloatTensor(embed_dim),
                                             requires_grad=False)
        nn.init.uniform_(self.first_decoder_input, -1, 1)

    def forward(self, inputs):
        """
        :type inputs: torch.Tensor, [batch, seq_len]
        """

        batch_size = inputs.shape[0]

        first_decoder_input = self.first_decoder_input.expand(
            batch_size, -1)  # [batch, embed_dim]

        enc_inputs = self.embedding(inputs)  # [batch, seq_len, embed_dim]

        first_encoder_hidden = self.encoder.init_hidden(
            inputs)  # Tuple(h0, c0)

        # ===============
        # encoder procedure
        # ===============
        enc_out, enc_hidden = self.encoder(
            enc_inputs,
            first_encoder_hidden)  # [batch, seq_len, layers * hidden_dim]

        # ===============
        # decoder procedure
        # ===============
        first_decoder_hidden = (enc_hidden[0][-1], enc_hidden[1][-1])

        (outputs,
         pointers), dec_hidden = self.decoder(enc_inputs, first_decoder_input,
                                              first_decoder_hidden, enc_out)

        return outputs, pointers
Example #21
class LcpStructuredNet(torch.nn.Module):
    def __init__(self, warm_start, include_G_weights):
        super(LcpStructuredNet, self).__init__()
        self.include_G_weights = include_G_weights

        self.f = torch.nn.Linear(2, 3, bias=False)
        self.f_bias = torch.nn.Parameter(torch.ones(3))
        self.G = torch.nn.Linear(2, 9, bias=False)
        self.G_bias = torch.nn.Parameter(torch.ones(9))

        # Correct dynamics solution
        if warm_start:
            self.G.weight.data.fill_(0)
            self.G.weight.data = self.add_noise(self.G.weight.data)
            self.G_bias = Parameter(self.add_noise(
                    torch.tensor([1, -1, 1,
                                 -1,  1, 1,
                                 -1, -1, 0]).float()))

            self.f.weight = Parameter(self.add_noise(
                    torch.tensor([[1,  1],
                                 [-1, -1],
                                  [0,  0]]).float()))
            self.f_bias = Parameter(self.add_noise(
                    torch.tensor([0, 0, 1]).float()))
        else:
            torch.nn.init.xavier_uniform_(self.f.weight)
            torch.nn.init.xavier_uniform_(self.G.weight)

        if not include_G_weights:
            self.G.weight.data.fill_(0)
    
    def add_noise(self, tensor):
        m = torch.distributions.normal.Normal(0, 0.1)
        return tensor + m.sample(tensor.shape).float()

    def forward(self, states):
        lambdas = states[:, 2:5]
        xus = states[:, 0:2]
        fxu = self.f(xus) + self.f_bias
        
        if self.include_G_weights:
            Gxu = (self.G(xus) + self.G_bias).view(-1, 3, 3)
        else:
            Gxu = (self.G_bias.expand(states.shape[0], 9)).view(-1, 3, 3) 
        
        lcp_slack = fxu + torch.bmm(Gxu,lambdas.unsqueeze(2)).squeeze(2)
        
        return lcp_slack
Example #22
class GumbelMF(nn.Module):
    def __init__(self,
                 n_users,
                 n_items,
                 n_dim,
                 n_obs,
                 lub=1.,
                 lib=1.,
                 luv=1.,
                 liv=1.,
                 tau=0.8,
                 loss=nn.MSELoss):
        super(GumbelMF, self).__init__()
        self.embed_user = BiasedEmbedding(n_users, n_dim, lb=lub, lv=luv)
        self.embed_item = BiasedEmbedding(n_items, n_dim, lb=lib, lv=liv)
        self.glob_bias = Parameter(torch.FloatTensor([0.01]))
        self.n_obs = n_obs
        self.lossf = loss()
        self.tau = tau

    def forward(self, u, i):
        u, i = u.squeeze(), i.squeeze()
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, lu = self.embed_user(u)
        bi, li = self.embed_item(i)
        if self.training:
            du, di = gumbel_softmax_correlated(lu, li)
            # du = gumbel_softmax(lu, self.tau)
            # di = gumbel_softmax(li, self.tau)
        else:
            du = F.softmax(lu)
            di = F.softmax(li)
        intx = hellinger(du, di)
        logodds = (bias + bi + bu + intx).squeeze()
        return logodds

    def loss(self, prediction, target):
        # average likelihood loss per example
        ex_llh = self.lossf(prediction, target)
        # regularization penalty summed over whole model
        epoch_reg = (self.embed_user.prior() + self.embed_item.prior())
        # penalty should be computed for a single example
        ex_reg = epoch_reg * 1.0 / self.n_obs
        return ex_llh + ex_reg
Example #23
class MFPoincare(nn.Module):
    def __init__(self, n_users, n_items, n_dim, n_obs):
        super(MFPoincare, self).__init__()
        self.embed_user = BiasedEmbedding(n_users, n_dim)
        self.embed_item = BiasedEmbedding(n_items, n_dim)
        self.glob_bias = Parameter(torch.Tensor(1, 1))
        self.n_obs = n_obs

    def forward(self, u, i):
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, vu = self.embed_user(u)
        bi, vi = self.embed_item(i)
        dist = poincare_distance(vu, vi)
        logodds = bias + bi + bu + dist
        return logodds

    def loss(self, prediction, target):
        # Don't know how to regularize poincare space yet!
        llh = F.binary_cross_entropy_with_logits(prediction, target)
        return llh
Example #24
class P_AnalysisDict(nn.Module):
    def __init__(self, col_synthesis_dict, col_sample, row_sample, num_sample,
                 device):
        super(P_AnalysisDict, self).__init__()
        self.col_synthesis_dict = col_synthesis_dict
        self.num_views = col_sample
        self.num_sample = num_sample
        self.row_sample = row_sample
        self.device = device
        self.Analysis_Dict = Parameter(
            pt.Tensor(col_synthesis_dict, row_sample))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.Analysis_Dict.size(0))
        self.Analysis_Dict.data.uniform_(-stdv, stdv)

    def forward(self, X):  #get the rebuilt value of sparse code
        P = self.Analysis_Dict.expand(X.size(0), self.col_synthesis_dict,
                                      self.row_sample)
        out = pt.bmm(P, X)
        return out
Example #25
class P_SynthesisDict(nn.Module):
    def __init__(self, col_SynthesisDict, col_sample, row_sample, num_sample,
                 device):
        super(P_SynthesisDict, self).__init__()
        self.col_synthesis_dict = col_SynthesisDict
        self.num_views = col_sample
        self.num_sample = num_sample
        self.row_sample = row_sample
        self.device = device
        self.Synthesis_Dict = Parameter(
            pt.Tensor(row_sample, col_SynthesisDict))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.Synthesis_Dict.size(0))  # axis=0 computes the standard deviation of each column
        self.Synthesis_Dict.data.uniform_(-stdv, stdv)

    def forward(self, Sparse_Code):  # obtain the rebuilt value of sample_x
        B = self.Synthesis_Dict.expand(self.num_sample, self.row_sample,
                                       self.col_synthesis_dict)
        out = pt.bmm(B, Sparse_Code)
        return out
Example #26
class MF(nn.Module):
    def __init__(self,
                 n_users,
                 n_items,
                 n_dim,
                 n_obs,
                 lub=1.,
                 lib=1.,
                 luv=1.,
                 liv=1.,
                 loss=nn.MSELoss):
        super(MF, self).__init__()
        self.embed_user = BiasedEmbedding(n_users, n_dim, lb=lub, lv=luv)
        self.embed_item = BiasedEmbedding(n_items, n_dim, lb=lib, lv=liv)
        self.glob_bias = Parameter(torch.FloatTensor([0.01]))
        self.n_obs = n_obs
        self.lossf = loss()

    def forward(self, u, i):
        u, i = u.squeeze(), i.squeeze()
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, vu = self.embed_user(u)
        bi, vi = self.embed_item(i)
        intx = (vu * vi).sum(dim=1)
        logodds = (bias + bi + bu + intx).squeeze()
        return logodds

    def loss(self, prediction, target):
        # average likelihood loss per example
        ex_llh = self.lossf(prediction, target)
        if self.training:
            # regularization penalty summed over whole model
            epoch_reg = (self.embed_user.prior() + self.embed_item.prior())
            # penalty should be computed for a single example
            ex_reg = epoch_reg * 1.0 / self.n_obs
        else:
            ex_reg = 0.0
        return ex_llh + ex_reg
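A hypothetical training-step sketch for the model above, assuming MF, the BiasedEmbedding class shown earlier (Example #2), and the usual imports are in scope; the counts, batch, and optimizer are illustrative only:

import torch

model = MF(n_users=100, n_items=50, n_dim=8, n_obs=1000)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
u = torch.randint(0, 100, (32,))
i = torch.randint(0, 50, (32,))
target = torch.randn(32)
pred = model(u, i)
loss = model.loss(pred, target)   # MSE plus the per-example prior penalty
opt.zero_grad()
loss.backward()
opt.step()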
Example #27
class MFPoly2(nn.Module):
    def __init__(self, n_users, n_items, n_dim, n_obs, lub=1.,
                 lib=1., luv=1., liv=1., loss=nn.MSELoss):
        super(MFPoly2, self).__init__()
        self.embed_user = BiasedEmbedding(n_users, n_dim, lb=lub, lv=luv)
        self.embed_item = BiasedEmbedding(n_items, n_dim, lb=lib, lv=liv)
        self.glob_bias = Parameter(torch.FloatTensor([0.01]))
        # maps a 2-dim input of [frame, frame^2] to a scalar log odds (dim=1),
        # effectively fitting a quadratic polynomial to the frame number
        self.poly = nn.Linear(2, 1)
        self.n_obs = n_obs
        self.lossf = loss()

    def forward(self, u, i, f):
        u, i = u.squeeze(), i.squeeze()
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, vu = self.embed_user(u)
        bi, vi = self.embed_item(i)
        intx = (vu * vi).sum(dim=1)
        frame = [f.unsqueeze(0), f.unsqueeze(0)**2.0]
        effect = self.poly(torch.t(torch.log(torch.cat(frame)))).squeeze()
        logodds = (bias + bi + bu + intx + effect).squeeze()
        return logodds

    def loss(self, prediction, target):
        # average likelihood loss per example
        ex_llh = self.lossf(prediction, target)
        if self.training:
            # regularization penalty summed over whole model
            epoch_reg = (self.embed_user.prior() + self.embed_item.prior())
            # penalty should be computed for a single example
            ex_reg = epoch_reg * 1.0 / self.n_obs
        else:
            ex_reg = 0.0
        return ex_llh + ex_reg
Example #28
class Decoder(nn.Module):
    """decoder layer
        * first decoder hidden
        * first decoder input
        * each step encoder inputs
        * each step encoder outputs
    received for conducting rnn layers and pointer network attention
    then give each decoder step pointer index and corresponding attention alphas
    """
    def __init__(self, embed_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim

        # lstm layer
        self.input2hidden = nn.Linear(embed_dim, hidden_dim * 4)
        self.hidden2hidden = nn.Linear(hidden_dim, hidden_dim * 4)
        self.hidden2out = nn.Linear(2 * hidden_dim, hidden_dim)

        # pointer attention layer
        self.attention = Attention(hidden_dim, hidden_dim)

        self.mask = Parameter(torch.ones(1), requires_grad=False)
        self.runner = Parameter(torch.zeros(1),
                                requires_grad=False)  # record each step index

    def forward(self, enc_inputs, dec_inputs, hidden, context):
        """
        :type enc_inputs: torch.Tensor, each encoder step inputs
        :type dec_inputs: torch.Tensor, first decoder inputs fed
        :type hidden: torch.Tensor, first decoder hidden fed
        :type context: torch.Tensor, each encoder step outputs
        """

        batch_size = enc_inputs.shape[0]
        seq_len = enc_inputs.shape[1]

        mask = self.mask.expand(batch_size, seq_len)
        self.attention.init_inf(mask.size())

        # .contiguous() gives a writable copy; filling an expanded view in place would fail
        runner = self.runner.expand(batch_size, seq_len).contiguous()
        for i in range(seq_len):
            runner.index_fill_(1, torch.tensor([i]), i)

        outputs = []
        pointers = []

        # ==============
        # decoder for each step
        # ==============
        for _ in range(seq_len):
            g_o, (c_t, h_t) = self.lstm_step(dec_inputs, hidden)
            h_t, alpha = self.attention_step(h_t, context, mask)

            hidden = (h_t, c_t)

            masked_alpha = alpha * mask  # [batch_size, seq_len]

            max_prob, max_index = masked_alpha.max(1)

            # update mask
            seen_pointer = (runner == max_index.unsqueeze(1).expand(
                -1, alpha.shape[1])).float()
            mask = mask * (1 - seen_pointer)

            # update dec_inputs from mask
            embedding_mask = seen_pointer.unsqueeze(2).expand(
                -1, -1, self.embed_dim)
            dec_inputs = enc_inputs[embedding_mask].view(
                batch_size, self.embed_dim)

            outputs.append(alpha.unsqueeze(0))
            pointers.append(max_index.unsqueeze(1))

        outputs = torch.cat(outputs,
                            dim=0).permute(1, 0, 2)  # [batch, dec_len, seq_len]
        pointers = torch.cat(pointers, dim=1)  # [batch, dec_len]

        return (outputs, pointers), hidden

    def lstm_step(self, x, hidden):
        """single lstm cell forward"""

        h, c = hidden

        gates_cells = self.input2hidden(x) + self.hidden2hidden(h)
        g_i, g_f, g_c, g_o = gates_cells.chunk(4, 1)

        g_i = F.sigmoid(g_i)
        g_f = F.sigmoid(g_f)
        g_c = F.tanh(g_c)
        g_o = F.sigmoid(g_o)

        c_t = (g_f * c) + (g_i * g_c)
        h_t = g_o * F.tanh(c_t)

        return g_o, (c_t, h_t)

    def attention_step(self, h_t, context, mask):
        """single pointer network attention forward"""

        alpha, attention_hidden_state = self.attention(h_t, context,
                                                       torch.eq(mask, 0))
        attention_hidden_state = F.tanh(
            self.hidden2out(torch.cat((attention_hidden_state, h_t), 1)))
        # acts as a fusion layer; the hidden state from the attention layer could also be used directly

        return attention_hidden_state, alpha
Example #29
class ViT(Module):
    """Vision transformer as per https://arxiv.org/abs/2010.11929."""
    @staticmethod
    def get_params():
        return {
            "image_size": cfg.TRAIN.IM_SIZE,
            "patch_size": cfg.VIT.PATCH_SIZE,
            "stem_type": cfg.VIT.STEM_TYPE,
            "c_stem_kernels": cfg.VIT.C_STEM_KERNELS,
            "c_stem_strides": cfg.VIT.C_STEM_STRIDES,
            "c_stem_dims": cfg.VIT.C_STEM_DIMS,
            "n_layers": cfg.VIT.NUM_LAYERS,
            "n_heads": cfg.VIT.NUM_HEADS,
            "hidden_d": cfg.VIT.HIDDEN_DIM,
            "mlp_d": cfg.VIT.MLP_DIM,
            "cls_type": cfg.VIT.CLASSIFIER_TYPE,
            "num_classes": cfg.MODEL.NUM_CLASSES,
        }

    @staticmethod
    def check_params(params):
        p = params
        err_str = "Input shape indivisible by patch size"
        assert p["image_size"] % p["patch_size"] == 0, err_str
        assert p["stem_type"] in ["patchify", "conv"], "Unexpected stem type"
        assert p["cls_type"] in ["token",
                                 "pooled"], "Unexpected classifier mode"
        if p["stem_type"] == "conv":
            err_str = "Conv stem layers mismatch"
            assert len(p["c_stem_dims"]) == len(p["c_stem_strides"]), err_str
            assert len(p["c_stem_strides"]) == len(
                p["c_stem_kernels"]), err_str
            err_str = "Stem strides unequal to patch size"
            assert p["patch_size"] == np.prod(p["c_stem_strides"]), err_str
            err_str = "Stem output dim unequal to hidden dim"
            assert p["c_stem_dims"][-1] == p["hidden_d"], err_str

    def __init__(self, params=None):
        super(ViT, self).__init__()
        p = ViT.get_params() if not params else params
        ViT.check_params(p)
        if p["stem_type"] == "patchify":
            self.stem = ViTStemPatchify(3, p["hidden_d"], p["patch_size"])
        elif p["stem_type"] == "conv":
            ks, ws, ss = p["c_stem_kernels"], p["c_stem_dims"], p[
                "c_stem_strides"]
            self.stem = ViTStemConv(3, ks, ws, ss)
        seq_len = (p["image_size"] // cfg.VIT.PATCH_SIZE)**2
        if p["cls_type"] == "token":
            self.class_token = Parameter(torch.zeros(1, 1, p["hidden_d"]))
            seq_len += 1
        else:
            self.class_token = None
        self.pos_embedding = Parameter(torch.zeros(seq_len, 1, p["hidden_d"]))
        self.encoder = ViTEncoder(p["n_layers"], p["hidden_d"], p["n_heads"],
                                  p["mlp_d"])
        self.head = ViTHead(p["hidden_d"], p["num_classes"])
        init_weights_vit(self)

    def forward(self, x):
        # (n, c, h, w) -> (n, hidden_d, n_h, n_w)
        x = self.stem(x)
        # (n, hidden_d, n_h, n_w) -> (n, hidden_d, (n_h * n_w))
        x = x.reshape(x.size(0), x.size(1), -1)
        # (n, hidden_d, (n_h * n_w)) -> ((n_h * n_w), n, hidden_d)
        x = x.permute(2, 0, 1)
        if self.class_token is not None:
            # Expand the class token to the full batch
            class_token = self.class_token.expand(-1, x.size(1), -1)
            x = torch.cat([class_token, x], dim=0)
        x = x + self.pos_embedding
        x = self.encoder(x)
        # `token` or `pooled` features for classification
        x = x[0, :, :] if self.class_token is not None else x.mean(dim=0)
        return self.head(x)

    @staticmethod
    def complexity(cx, params=None):
        """Computes model complexity. If you alter the model, make sure to update."""
        p = ViT.get_params() if not params else params
        ViT.check_params(p)
        if p["stem_type"] == "patchify":
            cx = ViTStemPatchify.complexity(cx, 3, p["hidden_d"],
                                            p["patch_size"])
        elif p["stem_type"] == "conv":
            ks, ws, ss = p["c_stem_kernels"], p["c_stem_dims"], p[
                "c_stem_strides"]
            cx = ViTStemConv.complexity(cx, 3, ks, ws, ss)
        seq_len = (p["image_size"] // cfg.VIT.PATCH_SIZE)**2
        if p["cls_type"] == "token":
            seq_len += 1
            cx["params"] += p["hidden_d"]
        # Params of position embeddings
        cx["params"] += seq_len * p["hidden_d"]
        cx = ViTEncoder.complexity(cx, p["n_layers"], p["hidden_d"],
                                   p["n_heads"], p["mlp_d"], seq_len)
        cx = ViTHead.complexity(cx, p["hidden_d"], p["num_classes"])
        return cx
Example #30
class dmn(Parameterized):
    # this is equivalent to
    # pyro.contrib.gp.models.model.GPModel
    r"""
    Base class for Dynamic Networks using the Latent Space representation.

    Each node :math:`i` has four Gaussian Processes associated:
    1. Location in the latent social space for link propensity (probability of being connected)
    2. Location in the latent social space for edge weight (size of the connection)
    3. Sociability of the node
    4. Popularity of the node

    There are two ways to train the DMN model:

    + Using an MCMC algorithm
    + Using variational inference on the pair :meth:`model`, :meth:`guide`

    """
    def __init__(self,
                 edgelist,
                 H_dim=3,
                 X=None,
                 weighted=True,
                 directed=True,
                 coord=False,
                 socpop=True,
                 jitter=1e-6,
                 whiten=False):

        super(dmn, self).__init__()

        Y, [all_nodes, Y_time,
            all_layers] = pydmn.util.edgelist_to_tensor(edgelist=edgelist)
        self.all_layers = all_layers

        self.set_data(Y=Y, Y_time=Y_time, H_dim=H_dim, X=X)
        self.whiten = whiten

        self.jitter = jitter

        self.weighted = weighted
        self.directed = directed
        self.coord = coord
        self.socpop = socpop

        # Dimensions of objects change if the network is weighted and/or directed
        self.lw_dim = (2 if weighted else 1)
        self.sr_dim = (2 if directed else 1)

        self.kernel = Parameterized()  # kernels modules will be added here
        self.gp = Parameterized()  # GP modules will be added here

        if self.weighted:
            sigma_Y = torch.ones(self.K_net)
            self.sigma_Y = Parameter(sigma_Y)
            self.set_constraint("sigma_Y", constraints.positive)

        ### Latent locations ###
        if coord:
            # Kernels for GP of latent locations #
            # Assume that the kernel is the same for all agents, i.e. the dynamics in the latent space is the same #
            self.kernel.coord = Parameterized()
            self.kernel.coord.link = gp.kernels.RBF(input_dim=1)
            if self.weighted:
                self.kernel.coord.weight = gp.kernels.RBF(input_dim=1)

            ## Dynamic latent locations ##
            self.gp.coord = Parameterized()

            # GP location (without mean fun) #
            # dimensions: (V_net, H_dim, sr_dim, lw_dim, T_net) = (1 per agent, 1 per lat dim, send & rec, link & weight, Time )
            gp_coord_loc = torch.randn(self.V_net, self.H_dim, self.sr_dim,
                                       self.lw_dim, self.T_net)
            self.gp.coord.loc = Parameter(gp_coord_loc)

            # Mean function: constant #
            # dimensions: (V_net, H_dim, sr_dim, lw_dim) = (1 per agent, 1 per lat dim, send & rec, link & weight)
            gp_coord_loc_mean = torch.randn(self.V_net, self.H_dim,
                                            self.sr_dim, self.lw_dim)
            self.gp.coord.loc_mean = Parameter(gp_coord_loc_mean)

            # Lower Cholesky of the Covariance matrix of latent coordinates #
            # dimensions: (V_net, H_dim, sr_dim, lw_dim, T_net, T_net) = (1 per agent, 1 per lat dim, send & rec, link & weight, Time, Time )
            gp_coord_cov_tril_unconst = torch.diag_embed(
                torch.ones(self.V_net, self.H_dim, self.sr_dim, self.lw_dim,
                           self.T_net))
            self.gp.coord.cov_tril_unconst = Parameter(
                gp_coord_cov_tril_unconst)
            self.gp.coord.cov_tril = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack([
                            transform_to(constraints.lower_cholesky)(
                                self.gp.coord.cov_tril_unconst[v, h, sr_i,
                                                               lw_i, :, :])
                            for lw_i in range(self.lw_dim)
                        ]) for sr_i in range(self.sr_dim)
                    ]) for h in range(self.H_dim)
                ]) for v in range(self.V_net)
            ])
        else:
            self.kernel.coord = None
            self.gp.coord = None

        ### Sociability and Popularity ###
        if socpop:
            # Kernels for GP of Sociability and Popularity params #
            # Assume that the kernel is the same for sociability and popularity #
            self.kernel.socpop = Parameterized()
            self.kernel.socpop.link = gp.kernels.RBF(input_dim=1)
            if self.weighted:
                self.kernel.socpop.weight = gp.kernels.RBF(input_dim=1)

            ## Dynamic Sociability and Popularity ##
            self.gp.socpop = Parameterized()

            # dimensions: (V_net, 2, lw_dim, T_net) = (1 per agent, soc & pop, link & weight, Time )
            gp_socpop_loc = torch.randn(self.V_net, 2, self.lw_dim, self.T_net)
            self.gp.socpop.loc = Parameter(gp_socpop_loc)

            # Mean function: constant #
            gp_socpop_loc_mean = torch.randn(self.V_net, 2, self.lw_dim)
            self.gp.socpop.loc_mean = Parameter(gp_socpop_loc_mean)

            # gp.socpop.cov_tril: lower Cholesky of the covariance matrix of the sociability/popularity GPs
            # dimensions: (V_net, 2, lw_dim, T_net, T_net) = (1 per agent, soc & pop, link & weight, Time, Time )
            gp_socpop_cov_tril_unconst = torch.diag_embed(
                torch.ones(self.V_net, 2, self.lw_dim, self.T_net))
            self.gp.socpop.cov_tril_unconst = Parameter(
                gp_socpop_cov_tril_unconst)
            self.gp.socpop.cov_tril = torch.stack([
                torch.stack([
                    torch.stack([
                        transform_to(constraints.lower_cholesky)(
                            self.gp.socpop.cov_tril_unconst[v, sp_i,
                                                            lw_i, :, :])
                        for lw_i in range(self.lw_dim)
                    ]) for sp_i in range(2)
                ]) for v in range(self.V_net)
            ])
        else:
            self.kernel.socpop = None
            self.gp.socpop = None

    # @autoname.scope(prefix="DMN") # generates error
    def model(self):
        self.set_mode("model")

        # Sample the coordinates
        # dimensions: (V_net, H_dim, sr_dim, lw_dim, T_net) = (1 per agent, 1 per lat dim, send & rec, link & weight, Time )
        if self.coord:

            # Calculates the lower Cholesky for the coordinate kernels (link / weight)
            Kff = [
                eval('self.kernel.coord.' + ['link', 'weight'][lw_i])(
                    self.Y_time).contiguous() for lw_i in range(self.lw_dim)
            ]
            for lw_i in range(self.lw_dim):
                Kff[lw_i].view(
                    -1)[::self.T_net +
                        1] += self.jitter  # add jitter to the diagonal
            Lff_coord = [Kff[lw_i].cholesky() for lw_i in range(self.lw_dim)]

            # Gaussian process for the latent coordinates #
            gp_coord = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack([
                            pydmn.util.GP_sample(
                                name=f'f_coord_v{v}_h{h}_sr{sr_i}_lw{lw_i}',
                                X=self.Y_time,
                                f_loc=self.gp.coord.loc[v, h, sr_i, lw_i, :],
                                f_loc_mean=self.gp.coord.loc_mean[v, h, sr_i,
                                                                  lw_i],
                                f_scale_tril=self.gp.coord.cov_tril[
                                    v, h, sr_i, lw_i, :, :],
                                Lff=Lff_coord[lw_i],
                                whiten=self.whiten)
                            for lw_i in range(self.lw_dim)
                        ]) for sr_i in range(self.sr_dim)
                    ]) for h in range(self.H_dim)
                ]) for v in range(self.V_net)
            ])

        # Sample the Sociability and Popularity ##
        # dimensions: (V_net, 2, lw_dim, T_net) = (1 per agent, soc & pop, link & weight, Time )
        if self.socpop:

            # Calculates lower cholesky for all soc, pop
            Kff = [
                eval('self.kernel.socpop.' + ['link', 'weight'][lw_i])(
                    self.Y_time).contiguous() for lw_i in range(self.lw_dim)
            ]
            for lw_i in range(self.lw_dim):
                Kff[lw_i].view(
                    -1)[::self.T_net +
                        1] += self.jitter  # add jitter to the diagonal
            Lff_socpop = [Kff[lw_i].cholesky() for lw_i in range(self.lw_dim)]

            # Gaussian process for the sociability and popularity #
            gp_socpop = torch.stack([
                torch.stack([
                    torch.stack([
                        pydmn.util.GP_sample(
                            name=
                            f'f_socpop_v{v}_{["soc","pop"][sp_i]}_lw{lw_i}',
                            X=self.Y_time,
                            f_loc=self.gp.socpop.loc[v, sp_i, lw_i, :],
                            f_loc_mean=self.gp.socpop.loc_mean[v, sp_i, lw_i],
                            f_scale_tril=self.gp.socpop.cov_tril[v, sp_i,
                                                                 lw_i, :, :],
                            Lff=Lff_socpop[lw_i],
                            whiten=self.whiten) for lw_i in range(self.lw_dim)
                    ]) for sp_i in range(2)
                ]) for v in range(self.V_net)
            ])

        ### Calculate Linear Predictors ###
        Y_linpred = torch.zeros(self.V_net, self.V_net, self.T_net,
                                self.lw_dim)

        # identifies diagonal elements, which will not be considered when computing the likelihood
        Y_diag = torch.diag_embed(torch.ones(
            self.T_net, self.V_net)).transpose(0, 2).flatten()

        if self.coord:
            # gp_coord.shape # gp_coord[v,h,sr_i,lw_i,t]
            send = gp_coord[:, :, 0, :, :].expand(self.V_net, self.V_net,
                                                  self.H_dim, self.lw_dim,
                                                  self.T_net).transpose(0, 1)
            if self.directed:
                receive = gp_coord[:, :,
                                   1, :, :].expand(self.V_net, self.V_net,
                                                   self.H_dim, self.lw_dim,
                                                   self.T_net)
            else:
                receive = gp_coord[:, :,
                                   0, :, :].expand(self.V_net, self.V_net,
                                                   self.H_dim, self.lw_dim,
                                                   self.T_net)
            Y_linpred += ((send - receive)**2).sum(dim=2).rsqrt().transpose(
                2, 3)

        if self.socpop:
            soc = gp_socpop[:, 0, :, :].expand(self.V_net, self.V_net,
                                               self.lw_dim,
                                               self.T_net).transpose(0, 1)
            pop = gp_socpop[:, 1, :, :].expand(self.V_net, self.V_net,
                                               self.lw_dim, self.T_net)
            Y_linpred += (soc + pop).transpose(2, 3)

        ### Sampling 0-1 links ###
        Y_probs = torch.sigmoid(Y_linpred[:, :, :, 0].flatten())[Y_diag == 0]
        Y_dist_link = dist.Bernoulli(Y_probs)
        Y_dist_link = Y_dist_link.expand_by(
            Y_probs.shape[:-Y_probs.dim()]).to_event(Y_probs.dim())
        return pyro.sample("Y_link",
                           Y_dist_link,
                           obs=(self.Y_link.flatten())[Y_diag == 0])

        ### Sampling weights ###
        if self.weighted:
            Y_weighted_id = (Y_diag == 0) & (self.Y.flatten() != 0)

            Y_loc = (Y_linpred[:, :, :, 1].flatten())[Y_weighted_id]
            Y_scale = self.sigma_Y.expand(self.V_net, self.V_net,
                                          self.T_net).flatten()[Y_weighted_id]

            Y_dist = dist.Normal(Y_loc, Y_scale)
            Y_dist = Y_dist.expand_by(Y_loc.shape[:-Y_loc.dim()]).to_event(
                Y_loc.dim())

            pyro.sample("Y", Y_dist, obs=(self.Y.flatten())[Y_weighted_id])

    # @autoname.scope(prefix="DMN") # generates error
    def guide(self):
        self.set_mode("guide")

        if self.coord:
            for v in range(self.V_net):
                for h in range(self.H_dim):
                    for sr_i in range(self.sr_dim):
                        for lw_i in range(self.lw_dim):
                            pyro.sample(
                                f'f_coord_v{v}_h{h}_sr{sr_i}_lw{lw_i}',
                                dist.MultivariateNormal(
                                    self.gp.coord.loc[v, h, sr_i, lw_i, :],
                                    scale_tril=self.gp.coord.cov_tril[
                                        v, h, sr_i, lw_i, :, :]).to_event(
                                            self.gp.coord.loc[v, h, sr_i,
                                                              lw_i, :].dim() -
                                            1))

        if self.socpop:
            for v in range(self.V_net):
                for sp_i in range(2):
                    for lw_i in range(self.lw_dim):
                        pyro.sample(
                            f'f_socpop_v{v}_{["soc","pop"][sp_i]}_lw{lw_i}',
                            dist.MultivariateNormal(
                                self.gp.socpop.loc[v, sp_i, lw_i, :],
                                scale_tril=self.gp.socpop.cov_tril[v, sp_i,
                                                                   lw_i, :, :]
                            ).to_event(self.gp.socpop.loc[v, sp_i,
                                                          lw_i, :].dim() - 1))
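        # The guide places an independent MultivariateNormal over every latent
        # GP trajectory sampled in the model, parameterised by the `loc` and
        # `cov_tril` tensors stored under `self.gp`.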

    def forward(self, Y_time_new, num_particles=30):
        r"""
        Computes something
        """
        self.set_mode("guide")
        # e.g. Y_time_new = torch.arange(0, 30, 0.25)

        ## Generate coordinates for new times##
        if self.coord:

            Lff_coord = []
            Kfs_coord = []
            Kss_coord = []
            for lw_i in range(self.lw_dim):
                Kff = getattr(self.kernel.coord, ['link', 'weight'][lw_i])(
                    self.Y_time).contiguous()
                Kff.view(-1)[::self.T_net +
                             1] += self.jitter  # add jitter to the diagonal
                Lff_coord.append(Kff.cholesky())

                Kfs_coord.append(
                    getattr(self.kernel.coord, ['link', 'weight'][lw_i])(
                        self.Y_time, Y_time_new))

                Kss_coord.append(
                    getattr(self.kernel.coord,
                            ['link', 'weight'][lw_i])(Y_time_new).contiguous())
                Kss_coord[lw_i].view(
                    -1)[::Kss_coord[lw_i].shape[0] +
                        1] += self.jitter  # add jitter to the diagonal
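            # Kff: covariance among observed times, Kfs: cross-covariance
            # between observed and new times, Kss: covariance among new times;
            # these are the ingredients of the GP predictive conditional below.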

            gp_coord_loc_and_cov_new = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack([
                            torch.stack(
                                pydmn.util.conditional(
                                    Xnew=Y_time_new,
                                    X=self.Y_time,
                                    kernel=None,
                                    f_loc=self.gp.coord.loc[v, h, sr_i,
                                                            lw_i, :],
                                    f_scale_tril=self.gp.coord.cov_tril[
                                        v, h, sr_i, lw_i, :, :],
                                    Lff=Lff_coord[lw_i],
                                    full_cov=True,
                                    whiten=self.whiten,
                                    jitter=self.jitter,
                                    Kfs=Kfs_coord[lw_i],
                                    Kss=Kss_coord[lw_i]))
                            for lw_i in range(self.lw_dim)
                        ]) for sr_i in range(self.sr_dim)
                    ]) for h in range(self.H_dim)
                ]) for v in range(self.V_net)
            ])

            gp_coord_loc_new = gp_coord_loc_and_cov_new[:, :, :, :, 0, 0, :]
            gp_coord_cov_new = gp_coord_loc_and_cov_new[:, :, :, :, 1, :, :]
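            # gp_coord_loc_new: predictive mean, gp_coord_cov_new: full
            # predictive covariance of each latent coordinate at the new times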

            gp_coord_new_sample = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack([
                            dist.MultivariateNormal(
                                gp_coord_loc_new[v, h, sr_i, lw_i, :] +
                                self.gp.coord.loc_mean[v, h, sr_i, lw_i],
                                gp_coord_cov_new[v, h, sr_i,
                                                 lw_i, :, :]).expand([
                                                     num_particles
                                                 ]).sample().transpose(0, 1)
                            for lw_i in range(self.lw_dim)
                        ]) for sr_i in range(self.sr_dim)
                    ]) for h in range(self.H_dim)
                ]) for v in range(self.V_net)
            ])

        ## Generate sociability and popularity for new times ##
        if self.socpop:

            Lff_socpop = []
            Kfs_socpop = []
            Kss_socpop = []
            for lw_i in range(self.lw_dim):
                Kff = getattr(self.kernel.socpop, ['link', 'weight'][lw_i])(
                    self.Y_time).contiguous()
                Kff.view(-1)[::self.T_net +
                             1] += self.jitter  # add jitter to the diagonal
                Lff_socpop.append(Kff.cholesky())

                Kfs_socpop.append(
                    getattr(self.kernel.socpop, ['link', 'weight'][lw_i])(
                        self.Y_time, Y_time_new))

                Kss_socpop.append(
                    getattr(self.kernel.socpop,
                            ['link', 'weight'][lw_i])(Y_time_new).contiguous())
                Kss_socpop[lw_i].view(
                    -1)[::Kss_socpop[lw_i].shape[0] +
                        1] += self.jitter  # add jitter to the diagonal
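            # Kff, Kfs and Kss play the same roles as in the coordinate block
            # above, now for the sociability/popularity GPs.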

            gp_socpop_loc_and_cov_new = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack(
                            pydmn.util.conditional(
                                Xnew=Y_time_new,
                                X=self.Y_time,
                                kernel=None,
                                f_loc=self.gp.socpop.loc[v, sp_i, lw_i, :],
                                f_scale_tril=self.gp.socpop.cov_tril[
                                    v, sp_i, lw_i, :, :],
                                Lff=Lff_socpop[lw_i],
                                full_cov=True,
                                whiten=self.whiten,
                                jitter=self.jitter,
                                Kfs=Kfs_socpop[lw_i],
                                Kss=Kss_socpop[lw_i]))
                        for lw_i in range(self.lw_dim)
                    ]) for sp_i in range(2)
                ]) for v in range(self.V_net)
            ])

            gp_socpop_loc_new = gp_socpop_loc_and_cov_new[:, :, :, 0, 0, :]
            gp_socpop_cov_new = gp_socpop_loc_and_cov_new[:, :, :, 1, :, :]

            gp_socpop_new_sample = torch.stack([
                torch.stack([
                    torch.stack([
                        dist.MultivariateNormal(
                            gp_socpop_loc_new[v, sp_i, lw_i, :] +
                            self.gp.socpop.loc_mean[v, sp_i, lw_i],
                            gp_socpop_cov_new[v, sp_i, lw_i, :, :]).expand(
                                [num_particles]).sample().transpose(0, 1)
                        for lw_i in range(self.lw_dim)
                    ]) for sp_i in range(2)
                ]) for v in range(self.V_net)
            ])

        ### Calculate Linear Predictors ###
        Y_linpred_new_sample = torch.zeros(self.V_net, self.V_net,
                                           Y_time_new.shape[0], self.lw_dim,
                                           num_particles)

        if self.coord:
            # gp_coord_new_sample.shape # gp_coord_new_sample[v,h,sr_i,lw_i,t,num_particles]
            send = gp_coord_new_sample[:, :, 0, :, :, :].expand(
                self.V_net, self.V_net, self.H_dim, self.lw_dim,
                Y_time_new.shape[0], num_particles).transpose(0, 1)
            if self.directed:
                receive = gp_coord_new_sample[:, :, 1, :, :, :].expand(
                    self.V_net, self.V_net, self.H_dim, self.lw_dim,
                    Y_time_new.shape[0], num_particles)
            else:
                receive = gp_coord_new_sample[:, :, 0, :, :, :].expand(
                    self.V_net, self.V_net, self.H_dim, self.lw_dim,
                    Y_time_new.shape[0], num_particles)
            Y_linpred_new_sample += ((send - receive)**2).sum(
                dim=2).rsqrt().transpose(2, 3)

        if self.socpop:
            soc = gp_socpop_new_sample[:, 0, :, :, :].expand(
                self.V_net, self.V_net, self.lw_dim, Y_time_new.shape[0],
                num_particles).transpose(0, 1)
            pop = gp_socpop_new_sample[:, 1, :, :, :].expand(
                self.V_net, self.V_net, self.lw_dim, Y_time_new.shape[0],
                num_particles)
            Y_linpred_new_sample += (soc + pop).transpose(2, 3)

        # posterior mean over particles: Y_linpred_new_sample.mean(dim=-1)
        return Y_linpred_new_sample

    def set_data(self, Y, Y_time, H_dim=3, X=None):
        """
        Sets data for dmn models.
        """

        assert (len(Y.shape) >= 3) and (len(Y.shape) <= 4)

        # adjacency matrices must be square
        assert Y.shape[0] == Y.shape[1]

        self.Y = Y
        self.V_net, self.T_net = self.Y.shape[0], self.Y.shape[2]

        self.Y_link = torch.where(Y != 0, torch.ones_like(Y),
                                  torch.zeros_like(Y))

        self.K_net = Y.shape[3] if len(Y.shape) == 4 else 1

        assert self.T_net == Y_time.shape[0]
        self.Y_time = Y_time

        assert int(H_dim) >= 1
        self.H_dim = int(H_dim)

        self.X = X
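
A minimal usage sketch follows. It is hypothetical: the class name `DMN`, its constructor arguments, and the synthetic data are assumptions for illustration; only `set_data`, `model`, `guide`, and `forward` come from the code above.

import torch
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Assumed constructor call; the real arguments depend on the full class definition.
dmn = DMN(directed=True, weighted=False)

# Synthetic dynamic network: V_net = 5 agents observed at T_net = 10 time points.
V_net, T_net = 5, 10
Y = (torch.rand(V_net, V_net, T_net) > 0.5).float()
Y_time = torch.arange(T_net).float()
dmn.set_data(Y=Y, Y_time=Y_time, H_dim=3)

# Stochastic variational inference with the model/guide pair defined above.
svi = SVI(dmn.model, dmn.guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(2000):
    loss = svi.step()

# Posterior draws of the linear predictors at new times; lw index 0 is the
# link component and the last dimension indexes the particles.
Y_time_new = torch.arange(0, 10, 0.5)
Y_linpred_new = dmn.forward(Y_time_new, num_particles=30)
link_probs = torch.sigmoid(Y_linpred_new[..., 0, :]).mean(dim=-1)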