class BiasedEmbedding(nn.Module):
    def __init__(self, n_feat, n_dim, lv=1.0, lb=1.0):
        super(BiasedEmbedding, self).__init__()
        self.vect = nn.Embedding(n_feat, n_dim)
        self.bias = nn.Embedding(n_feat, 1)
        self.off_vect = Parameter(torch.zeros(1, n_dim))
        self.mul_vect = Parameter(torch.ones(1, n_dim))
        self.off_bias = Parameter(torch.zeros(1))
        self.mul_bias = Parameter(torch.ones(1))
        self.n_dim = n_dim
        self.n_feat = n_feat
        self.lv = lv
        self.lb = lb
        self.vect.weight.data.normal_(0, 1.0 / n_dim)
        self.bias.weight.data.normal_(0, 1.0 / n_dim)

    def forward(self, index):  # was __call__; forward keeps nn.Module hooks working
        assert (index.max() < self.n_feat).all()
        assert (index.min() >= 0).all()
        off_vect = self.off_vect.expand(len(index), self.n_dim).squeeze()
        off_bias = self.off_bias.expand(len(index), 1).squeeze()
        bias = off_bias + self.mul_bias * self.bias(index).squeeze()
        vect = off_vect + self.mul_vect * self.vect(index)
        return bias, vect

    def prior(self):
        loss = ((self.vect.weight ** 2.0).sum() * self.lv +
                (self.bias.weight ** 2.0).sum() * self.lb)
        return loss
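
# A minimal usage sketch (not from the original source): embeds a batch of
# indices; assumes the BiasedEmbedding class above plus `import torch`.
emb = BiasedEmbedding(n_feat=100, n_dim=8)
idx = torch.randint(0, 100, (4,))
bias, vect = emb(idx)   # bias: (4,), vect: (4, 8)
reg = emb.prior()       # scalar L2 penalty on the embedding tables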
class RNN(nn.Module):
    def __init__(self, input_size=9, hidden_size=100, output_size=9):
        super(RNN, self).__init__()
        MAX = 4000
        EOS = torch.from_numpy(np.array(8 * [0] + [1])).float()
        self.hidden_size = hidden_size
        self.LSTM = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        self.activation = torch.sigmoid  # functional.sigmoid is deprecated
        # Cast inside the constructor: calling .float() on a Parameter would
        # return a plain Tensor and skip parameter registration.
        self.hidden_state0 = Parameter(torch.zeros(1, hidden_size).float())
        self.cell_state0 = Parameter(torch.zeros(1, hidden_size).float())
        self.zero_vector = Parameter(torch.zeros(MAX, 9).float())
        # self.zero_vector = Parameter(EOS.expand(MAX, 9))

    def step(self, input_vector, hidden_state, cell_state):
        hidden_state, cell_state = self.LSTM(input_vector,
                                             (hidden_state, cell_state))
        return hidden_state, cell_state, self.fc(hidden_state)

    def forward(self, input_vectors):
        N = input_vectors.shape[0]
        T = input_vectors.shape[1] - 1
        hidden_state = self.hidden_state0.expand(N, self.hidden_size)
        cell_state = self.cell_state0.expand(N, self.hidden_size)
        for t in range(T + 1):
            hidden_state, cell_state, _ = self.step(input_vectors[:, t, :],
                                                    hidden_state, cell_state)
        outputs = []
        for t in range(T):
            hidden_state, cell_state, output = self.step(self.zero_vector[:N, :],
                                                         hidden_state, cell_state)
            outputs.append(self.activation(output.unsqueeze(2).transpose(1, 2)))
        return torch.cat(outputs, 1)
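
# A minimal usage sketch (not from the original source): encodes two length-5
# sequences and decodes T = 4 output steps; assumes the RNN class above and
# the imports it uses (torch, numpy as np, Parameter).
rnn = RNN(input_size=9, hidden_size=100, output_size=9)
x = torch.rand(2, 5, 9)   # (batch, time, features)
y = rnn(x)                # (batch, time - 1, features), sigmoid outputs
assert y.shape == (2, 4, 9)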
class BasisNet(torch.nn.Module):
    def __init__(self, include_G_weights):
        super(BasisNet, self).__init__()
        self.include_G_weights = include_G_weights
        self.psd_vec = torch.nn.Linear(2, 3, bias=False)
        self.psd_vec_bias = Parameter(torch.ones(3))
        self.nonneg_matrix = torch.nn.Linear(2, 9, bias=False)
        self.nonneg_matrix_bias = Parameter(torch.ones(9))
        self.antisym = torch.nn.Linear(2, 3, bias=False)
        self.antisym_bias = Parameter(torch.ones(3))
        self.antisym_basis = torch.tensor([[[0, 1, 0], [-1, 0, 0], [0, 0, 0]],
                                           [[0, 0, 1], [0, 0, 0], [-1, 0, 0]],
                                           [[0, 0, 0], [0, 0, 1], [0, -1, 0]]]).float()
        self.f = torch.nn.Linear(2, 3, bias=True)
        if not include_G_weights:
            self.psd_vec.weight.data.fill_(0)
            self.nonneg_matrix.weight.data.fill_(0)
            self.antisym.weight.data.fill_(0)

    def forward(self, states):
        lambdas = states[:, 2:5]
        Gxu, fxu = self.get_lcps(states)
        lcp_slack = fxu + torch.bmm(Gxu, lambdas.unsqueeze(2)).squeeze(2)
        return lcp_slack

    def get_lcps(self, states):
        lambdas = states[:, 2:5]
        xus = states[:, 0:2]
        fxu = self.f(xus)
        if self.include_G_weights:
            psd_vec = self.psd_vec(xus) + self.psd_vec_bias
            nonneg_matrix = self.nonneg_matrix(xus) + self.nonneg_matrix_bias
            antisym_vec = self.antisym(xus) + self.antisym_bias
        else:
            psd_vec = self.psd_vec_bias.expand(states.shape[0], 3)
            nonneg_matrix = self.nonneg_matrix_bias.expand(states.shape[0], 9)
            antisym_vec = self.antisym_bias.expand(states.shape[0], 3)
        psd_term = torch.bmm(psd_vec.unsqueeze(2), psd_vec.unsqueeze(1))
        # reshape (not view): the bias-only branch yields a non-contiguous
        # expanded tensor, which .view() rejects
        nonneg_term = F.relu(nonneg_matrix.reshape(-1, 3, 3))
        antisym_vec = antisym_vec.unsqueeze(2).unsqueeze(3)
        antisym_basis = self.antisym_basis.expand(antisym_vec.shape[0], 3, 3, 3)
        antisym_term = torch.sum(antisym_vec * antisym_basis, dim=1)
        Gxu = psd_term + nonneg_term + antisym_term
        return Gxu, fxu
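
# A minimal usage sketch (not from the original source): builds G(x, u) from
# the PSD, nonnegative, and antisymmetric terms; assumes the BasisNet class
# above plus `import torch` and `torch.nn.functional as F`.
net = BasisNet(include_G_weights=True)
states = torch.randn(4, 5)        # columns 0:2 are (x, u), columns 2:5 are lambdas
Gxu, fxu = net.get_lcps(states)   # (4, 3, 3) and (4, 3)
slack = net(states)               # (4, 3) LCP slack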
class AttentionModule(nn.Module):
    def __init__(self, config):
        super(AttentionModule, self).__init__()
        self.config = config
        self.rnn = nn.LSTM(config.word_emb_dim + config.resnet_features,
                           # self.rnn = nn.LSTM(config.word_emb_dim,
                           config.lstm_hidden_size,
                           config.recurrent_layers,
                           batch_first=False,
                           bidirectional=False,
                           dropout=config.recurrent_dropout)
        self.weight_feature2attn = Parameter(torch.Tensor(config.resnet_features))
        self.weight_hidden2attn = Parameter(
            torch.Tensor(config.lstm_hidden_size, config.resnet_fmap_size))
        self.bias_attention = Parameter(torch.Tensor(config.resnet_fmap_size))
        self.mlp = nn.Sequential(
            nn.Linear(config.lstm_hidden_size, config.mlp_hidden_size),
            # nn.Linear(config.lstm_hidden_size + config.resnet_features, config.mlp_hidden_size),
            nn.ReLU(),
            nn.Dropout(config.mlp_dropout),
            nn.Linear(config.mlp_hidden_size, 1),
            nn.Sigmoid()
        )
        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.parameters():
            stdv = 2.0 / math.sqrt(weight.size(-1))
            weight.data.uniform_(-stdv, stdv)

    def forward(self, rawwords_emb, image_features):
        words_emb, lengths = pad_packed_sequence(rawwords_emb, batch_first=False)
        seq_len, batch_size = words_emb.size()[0:2]
        hidden_state = Variable(torch.zeros(self.config.recurrent_layers,
                                            batch_size, self.config.lstm_hidden_size))
        cell_state = Variable(torch.zeros(self.config.recurrent_layers,
                                          batch_size, self.config.lstm_hidden_size))
        if self.config.cuda:
            hidden_state, cell_state = hidden_state.cuda(), cell_state.cuda()
        weight = self.weight_feature2attn.expand(
            batch_size, self.config.resnet_features).unsqueeze(2)
        feature2a = torch.bmm(image_features.transpose(1, 2), weight).squeeze(2)
        outputs = []
        for i in range(seq_len):
            # feature2a, attn_weight are both (batch_size, 512)
            attn_weight = F.softmax(
                feature2a + torch.mm(hidden_state[-1], self.weight_hidden2attn) +
                self.bias_attention.expand(batch_size, self.config.resnet_fmap_size),
                dim=1)
            attn_out = torch.bmm(image_features, attn_weight.unsqueeze(2)).squeeze(2)  # (bz, 512)
            inputs = torch.cat([words_emb[i], attn_out], 1).unsqueeze(0)
            output, (hidden_state, cell_state) = self.rnn(inputs, (hidden_state, cell_state))
            outputs.append(output.squeeze(0))
        outputs = torch.cat([outputs[lengths[i] - 1][i].unsqueeze(0)
                             for i in range(len(lengths))], 0)
        scores = self.mlp(outputs).squeeze(1)
        '''
        outputs, (hidden_state, cell_state) = self.rnn(rawwords_emb)
        inputs = torch.cat([hidden_state[-1], image_features.mean(2).squeeze(2)], 1)
        scores = self.mlp(inputs).squeeze(1)
        '''
        return scores

    def eval(self):
        return self.train(False)

    def train(self, mode=True):  # accept `mode` to match nn.Module's signature
        self.rnn.train(mode)
        self.mlp.train(mode)
        return self
def initial_state_means_for_batch(self, parameters: Parameter, num_groups: int, **kwargs) -> Tensor:
    """
    Most subclasses can use this default. Handles rearranging of state-means based on
    `for_batch` keyword args. E.g., a discrete seasonal process with a state-element
    for each season would need to know which season the batch starts on.
    """
    return parameters.expand(num_groups, -1)
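
# A minimal illustration (not from the original source) of the expand call
# above: a (1, state_dim) parameter is broadcast to one row per group without
# copying storage.
import torch
from torch.nn import Parameter

params = Parameter(torch.zeros(1, 4))
batched = params.expand(3, -1)   # shape (3, 4); all rows share one storage
assert batched.shape == (3, 4)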
class Affine(Transform):
    def __init__(self, loc=0.0, scale=1.0, learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(1, -1)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(1, -1)
        self.loc = loc.float()
        self.scale = scale.float()
        self.n_dims = len(loc)
        if learnable:
            self.loc = Parameter(self.loc)
            self.scale = Parameter(self.scale)

    def forward(self, x):
        return self.loc + self.scale * x

    def inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        return torch.log(torch.abs(self.scale.expand(x.size()))).sum(-1)

    def get_parameters(self):
        return {'type': 'affine',
                'loc': self.loc.detach().numpy(),
                'scale': self.scale.detach().numpy()}
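
# A minimal round-trip sketch (not from the original source): assumes the
# Affine class above (with its Transform base importable) plus `import torch`.
t = Affine(loc=1.0, scale=2.0)
x = torch.randn(5, 1)
y = t.forward(x)                        # loc + scale * x
assert torch.allclose(t.inverse(y), x)  # inverse undoes forward
ladj = t.log_abs_det_jacobian(x, y)     # log|scale|, summed over the last dim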
class SqueezedAttFeatTrans(nn.Module):
    def __init__(self, config, name):
        super(SqueezedAttFeatTrans, self).__init__()
        self.config = config
        self.name = name
        self.in_feat_dim = config.in_feat_dim
        self.num_attractors = config.num_attractors
        # Disable feature squeezing in in_ator_trans.
        config1 = copy.copy(config)
        config1.feat_dim = config1.in_feat_dim
        config1.num_modes = 1
        self.in_ator_trans = CrossAttFeatTrans(config1, name + '-in-squeeze')
        self.ator_out_trans = CrossAttFeatTrans(config, name + '-squeeze-out')
        self.attractors = Parameter(
            torch.randn(1, self.num_attractors, self.in_feat_dim))

    def forward(self, in_feat, attention_mask=None):
        # in_feat: [B, 196, 1792]
        batch_size = in_feat.shape[0]
        batch_attractors = self.attractors.expand(batch_size, -1, -1)
        new_batch_attractors = self.in_ator_trans(batch_attractors, in_feat)
        out_feat = self.ator_out_trans(in_feat, new_batch_attractors)
        self.attention_scores = self.ator_out_trans.attention_scores
        return out_feat
class ReparamNormal_Mu_Logvar(ReparamNormal):
    def __init__(self, output_dim=2, **kwargs):
        super().__init__()
        mu = kwargs.get('mu', torch.zeros(output_dim))
        self.mu_param = Parameter(mu.unsqueeze(0))
        logvar = kwargs.get('logvar', torch.zeros(output_dim))
        self.logvar_param = Parameter(logvar.unsqueeze(0))

    def forward(self, x):
        s = x.size(0)
        o = self.mu_param.size(1)
        return self.mu_param.expand(s, o), self.logvar_param.expand(s, o)

    def __repr__(self):
        # size(1) is the output dimension; size(0) is the broadcast dim of 1
        return super().__repr__() + \
            '(_ -> {})'.format(self.mu_param.size(1))
class ConstantDiscountRateLayer(DiscountLayer):
    def __init__(self, forward_rate, time_len=MAX_YR_LEN):
        Module.__init__(self)
        self.time_len = time_len
        if isinstance(forward_rate, float):
            forward_rate = torch.tensor([forward_rate])
        self._forward_rate = Parameter(forward_rate)
        self.forward_rate = self._forward_rate.expand(time_len)
class Rnn(Module):
    """ A BiLSTM or BiGRU """

    def __init__(self, input_dim, hidden_dim, cell_type=LSTM, gpu=False):
        super(Rnn, self).__init__()
        self.hidden_dim = hidden_dim
        self.to_cuda = to_cuda(gpu)
        self.input_dim = input_dim
        self.cell_type = cell_type
        self.rnn = self.cell_type(input_size=self.input_dim,
                                  hidden_size=hidden_dim,
                                  num_layers=1,
                                  bidirectional=True)
        self.num_directions = 2  # We're a *bi*LSTM
        self.start_hidden_state = \
            Parameter(self.to_cuda(
                torch.randn(self.num_directions, 1, self.hidden_dim)
            ))
        self.start_cell_state = \
            Parameter(self.to_cuda(
                torch.randn(self.num_directions, 1, self.hidden_dim)
            ))

    def forward(self, batch, debug=0, dropout=None):
        """ Run a biLSTM over the batch of docs, return their hidden states (padded). """
        b = len(batch.docs)
        docs_vectors = [
            torch.index_select(batch.embeddings_matrix, 1, doc).t()
            for doc in batch.docs
        ]
        # Assumes/requires that `batch.docs` is sorted by decreasing doc length.
        # This gets done in `chunked_sorted`.
        packed = pack_padded_sequence(torch.stack(docs_vectors, dim=1),
                                      lengths=list(batch.doc_lens))
        # run the biLSTM
        starts = (self.start_hidden_state.expand(self.num_directions, b,
                                                 self.hidden_dim).contiguous(),
                  self.start_cell_state.expand(self.num_directions, b,
                                               self.hidden_dim).contiguous())
        outs, _ = self.rnn(packed, starts)
        return outs
class Attention(nn.Module):
    """pointer network attention mechanism"""

    def __init__(self, input_dim, hidden_dim):
        super(Attention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.d_linear = nn.Linear(input_dim, hidden_dim)
        self.e_linear = nn.Linear(input_dim, hidden_dim)
        self.V = Parameter(torch.FloatTensor(hidden_dim), requires_grad=True)
        # a buffer rather than a Parameter: init_inf reassigns this attribute,
        # which nn.Module forbids for registered parameters; it also avoids
        # gradient computation at certain steps
        self.register_buffer('inf', torch.FloatTensor([float("-inf")]))

    def init_inf(self, size):
        self.inf = self.inf.expand(*size)

    def forward(self, hidden, context, mask):
        """
        :type hidden: torch.Tensor, [batch_size, hidden_dim]
        :type context: torch.Tensor, [batch_size, seq_len, hidden_dim]
        :type mask: torch.Tensor, [batch_size, seq_len]
        """
        batch_size = context.shape[0]
        seq_len = context.shape[1]
        # repeat the projected decoder state across the sequence dimension
        di = self.d_linear(hidden).unsqueeze(1).expand(-1, seq_len, -1)  # [batch, seq_len, hidden_dim]
        ei = self.e_linear(context)  # [batch, seq_len, hidden_dim]
        ui = torch.bmm(self.V.expand(batch_size, 1, -1),
                       torch.tanh(di + ei).permute(0, 2, 1)).squeeze(1)  # [batch, seq_len]
        # mask already-pointed units before the softmax
        ui[mask] = self.inf[mask]
        alpha = F.softmax(ui, dim=1)  # [batch, seq_len]
        attention_hidden_state = torch.bmm(ei.permute(0, 2, 1),
                                           alpha.unsqueeze(2)).squeeze(2)  # [batch, hidden_dim]
        return alpha, attention_hidden_state
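
# A minimal usage sketch (not from the original source): scores a random
# decoder state against a length-5 context; assumes the Attention class above
# plus `import torch` and `torch.nn as nn`.
attn = Attention(input_dim=8, hidden_dim=8)
nn.init.uniform_(attn.V, -1, 1)             # V is left uninitialized by the class
attn.init_inf((2, 5))                       # must be called once per batch shape
hidden = torch.randn(2, 8)
context = torch.randn(2, 5, 8)
mask = torch.zeros(2, 5, dtype=torch.bool)  # True marks positions to suppress
alpha, state = attn(hidden, context, mask)  # [2, 5] weights, [2, 8] context vector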
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True, layernorm=False, dropoutr=0):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight = Parameter(
            torch.Tensor(input_size + hidden_size, 3 * hidden_size))
        self.bias = Parameter(torch.Tensor(3 * hidden_size)) if bias else None
        self.layernorm = nn.ModuleList(
            [nn.LayerNorm(hidden_size) for _ in range(3)]) if layernorm else None
        self.dropoutr = nn.Dropout(dropoutr) if dropoutr > 0 else None
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
        if self.layernorm:
            for ln in self.layernorm:
                ln.reset_parameters()

    def forward(self, input, hx=None):
        # input: (batch, input_size)
        # hx: tuple of (batch, hidden_size), (batch, hidden_size)
        if hx is None:
            hx = input.new_zeros(input.size(0), self.hidden_size,
                                 requires_grad=False)
            hx = (hx, hx)
        h_0, c_0 = hx
        pre_gate = torch.mm(torch.cat((input, h_0), 1), self.weight)
        if self.bias is not None:  # `if self.bias:` is ambiguous for tensors
            pre_gate = pre_gate + self.bias  # broadcasts over the batch dim
        f, o, g = pre_gate.split(self.hidden_size, 1)
        if self.layernorm:
            f = self.layernorm[0](f)
            o = self.layernorm[1](o)
            g = self.layernorm[2](g)
        f = torch.sigmoid(f + 1.)  # _forget_bias
        i = 1. - f  # input and forget gates are coupled
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        if self.dropoutr:
            g = self.dropoutr(g)  # recurrent dropout without memory loss
        c_1 = f * c_0 + i * g
        h_1 = o * torch.tanh(c_1)
        return h_1, c_1
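
# A minimal usage sketch (not from the original source): unrolls the cell over
# a short sequence; assumes the LSTMCell class above plus `import torch`.
cell = LSTMCell(input_size=10, hidden_size=20, layernorm=True, dropoutr=0.1)
x = torch.randn(7, 3, 10)   # (time, batch, input_size)
state = None                # zero-initialized inside the cell on the first step
for t in range(x.size(0)):
    state = cell(x[t], state)
h_final, c_final = state    # each (3, 20)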
class TiedLinear(torch.nn.Module):
    """
    TiedLinear is a linear layer with shared parameters for features between
    (output) classes that takes as input a tensor X with dimensions
        (batch size) X (output_dim) X (in_features)
    where:
        output_dim is the desired output dimension/# of classes
        in_features are the features with shared weights across the classes
    """

    def __init__(self, in_features, output_dim, bias=False):
        super(TiedLinear, self).__init__()
        self.in_features = in_features
        self.output_dim = output_dim
        self.weight = Parameter(torch.Tensor(1, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(1, in_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        # Broadcast parameters to the correct matrix dimensions for matrix
        # multiplication: this does NOT create new parameters: i.e. each
        # row of in_features of parameters are connected and will adjust
        # to the same values.
        self.W = self.weight.expand(output_dim, -1)
        if self.bias is not None:
            self.B = self.bias.expand(output_dim, -1)

    def reset_parameters(self):
        # size(1) is the fan-in (in_features); size(0) is always 1 here
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, X, index, mask):
        output = X.mul(self.W)
        if self.bias is not None:
            output += self.B
        output = output.sum(2)
        # Add our mask so that invalid domain classes for a given variable/VID
        # have a large negative value, resulting in a softmax probability
        # of de facto 0.
        output.index_add_(0, index, mask)
        return output
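
# A minimal usage sketch (not from the original source): scores a batch of two
# examples over three classes; assumes the TiedLinear class above plus
# `import torch`.
tied = TiedLinear(in_features=4, output_dim=3)
X = torch.randn(2, 3, 4)       # (batch, output_dim, in_features)
index = torch.tensor([0, 1])   # which batch rows each mask row is added to
mask = torch.zeros(2, 3)       # large negative entries would disable classes
logits = tied(X, index, mask)  # (batch, output_dim)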
class Isotropy(Kernel):
    """
    Base class for a family of isotropic covariance kernels which are functions of the
    distance :math:`|x-z|/l`, where :math:`l` is the length-scale parameter.

    By default, the parameter ``lengthscale`` has size 1. To use the anisotropic version
    (a different lengthscale for each dimension), make sure that ``lengthscale`` has
    size equal to ``input_dim``.

    :param torch.Tensor lengthscale: Length-scale parameter of this kernel.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None):
        super(Isotropy, self).__init__(input_dim, active_dims)

        variance = torch.tensor(1.) if variance is None else variance
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

        lengthscale = torch.tensor(1.) if lengthscale is None else lengthscale
        self.lengthscale = Parameter(lengthscale)
        self.set_constraint("lengthscale", constraints.positive)

    def _square_scaled_dist(self, X, Z=None):
        r"""
        Returns :math:`\|\frac{X-Z}{l}\|^2`.
        """
        if Z is None:
            Z = X
        X = self._slice_input(X)
        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        scaled_X = X / self.lengthscale
        scaled_Z = Z / self.lengthscale
        X2 = (scaled_X ** 2).sum(1, keepdim=True)
        Z2 = (scaled_Z ** 2).sum(1, keepdim=True)
        XZ = scaled_X.matmul(scaled_Z.t())
        r2 = X2 - 2 * XZ + Z2.t()
        return r2.clamp(min=0)

    def _scaled_dist(self, X, Z=None):
        r"""
        Returns :math:`\|\frac{X-Z}{l}\|`.
        """
        return _torch_sqrt(self._square_scaled_dist(X, Z))

    def _diag(self, X):
        """
        Calculates the diagonal part of covariance matrix on active features.
        """
        return self.variance.expand(X.size(0))
class ConstantMean(Mean):
    def __init__(self, constant=None):
        super(ConstantMean, self).__init__()
        constant = torch.zeros(1) if constant is None else constant
        self.constant = Parameter(constant)

    def forward(self, X):
        return self.constant.expand(X.size(0))
class Constant(Kernel):
    r"""
    Implementation of Constant kernel:

        :math:`k(x, z) = \sigma^2.`
    """

    def __init__(self, input_dim, variance=None, active_dims=None):
        super(Constant, self).__init__(input_dim, active_dims)
        variance = torch.tensor(1.) if variance is None else variance
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self.variance.expand(X.size(0))
        if Z is None:
            Z = X
        return self.variance.expand(X.size(0), Z.size(0))
class WhiteNoise(Kernel):
    r"""
    Implementation of WhiteNoise kernel:

        :math:`k(x, z) = \sigma^2 \delta(x, z),`

    where :math:`\delta` is a Dirac delta function.
    """

    def __init__(self, input_dim, variance=None, active_dims=None):
        super(WhiteNoise, self).__init__(input_dim, active_dims)
        variance = torch.tensor(1.) if variance is None else variance
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self.variance.expand(X.size(0))
        if Z is None:
            return self.variance.expand(X.size(0)).diag()
        else:
            return X.data.new_zeros(X.size(0), Z.size(0))
class Periodic(Kernel):
    r"""
    Implementation of Periodic kernel:

        :math:`k(x,z)=\sigma^2\exp\left(-2\times\frac{\sin^2(\pi(x-z)/p)}{l^2}\right),`

    where :math:`p` is the ``period`` parameter.

    References:

    [1] `Introduction to Gaussian processes`, David J.C. MacKay

    :param torch.Tensor lengthscale: Length scale parameter of this kernel.
    :param torch.Tensor period: Period parameter of this kernel.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, period=None,
                 active_dims=None):
        super(Periodic, self).__init__(input_dim, active_dims)

        variance = torch.tensor(1.) if variance is None else variance
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

        lengthscale = torch.tensor(1.) if lengthscale is None else lengthscale
        self.lengthscale = Parameter(lengthscale)
        self.set_constraint("lengthscale", constraints.positive)

        period = torch.tensor(1.) if period is None else period
        self.period = Parameter(period)
        self.set_constraint("period", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self.variance.expand(X.size(0))
        if Z is None:
            Z = X
        X = self._slice_input(X)
        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        d = X.unsqueeze(1) - Z.unsqueeze(0)
        scaled_sin = torch.sin(math.pi * d / self.period) / self.lengthscale
        return self.variance * torch.exp(-2 * (scaled_sin ** 2).sum(-1))
class PointerNet(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, layers, dropout):
        super(PointerNet, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.layers = layers
        self.dropout = dropout
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = Encoder(embed_dim, hidden_dim, layers, dropout)
        self.decoder = Decoder(embed_dim, hidden_dim)
        self.first_decoder_input = Parameter(torch.FloatTensor(embed_dim),
                                             requires_grad=False)
        nn.init.uniform_(self.first_decoder_input, -1, 1)

    def forward(self, inputs):
        """
        :type inputs: torch.Tensor, [batch, seq_len]
        """
        batch_size = inputs.shape[0]
        first_decoder_input = self.first_decoder_input.expand(batch_size, -1)  # [batch, embed_dim]
        enc_inputs = self.embedding(inputs)  # [batch, seq_len, embed_dim]
        first_encoder_hidden = self.encoder.init_hidden(inputs)  # Tuple(h0, c0)
        # ===============
        # encoder procedure
        # ===============
        enc_out, enc_hidden = self.encoder(
            enc_inputs, first_encoder_hidden)  # [batch, seq_len, layers * hidden_dim]
        # ===============
        # decoder procedure
        # ===============
        first_decoder_hidden = (enc_hidden[0][-1], enc_hidden[1][-1])
        (outputs, pointers), dec_hidden = self.decoder(enc_inputs,
                                                       first_decoder_input,
                                                       first_decoder_hidden,
                                                       enc_out)
        return outputs, pointers
class LcpStructuredNet(torch.nn.Module):
    def __init__(self, warm_start, include_G_weights):
        super(LcpStructuredNet, self).__init__()
        self.include_G_weights = include_G_weights
        self.f = torch.nn.Linear(2, 3, bias=False)
        self.f_bias = torch.nn.Parameter(torch.ones(3))
        self.G = torch.nn.Linear(2, 9, bias=False)
        self.G_bias = torch.nn.Parameter(torch.ones(9))
        # Correct dynamics solution
        if warm_start:
            self.G.weight.data.fill_(0)
            self.G.weight.data = self.add_noise(self.G.weight.data)
            self.G_bias = Parameter(self.add_noise(
                torch.tensor([1, -1, 1, -1, 1, 1, -1, -1, 0]).float()))
            self.f.weight = Parameter(self.add_noise(
                torch.tensor([[1, 1], [-1, -1], [0, 0]]).float()))
            self.f_bias = Parameter(self.add_noise(
                torch.tensor([0, 0, 1]).float()))
        else:
            torch.nn.init.xavier_uniform_(self.f.weight)
            torch.nn.init.xavier_uniform_(self.G.weight)
        if not include_G_weights:
            self.G.weight.data.fill_(0)

    def add_noise(self, tensor):
        m = torch.distributions.normal.Normal(0, 0.1)
        return tensor + m.sample(tensor.shape).float()

    def forward(self, states):
        lambdas = states[:, 2:5]
        xus = states[:, 0:2]
        fxu = self.f(xus) + self.f_bias
        if self.include_G_weights:
            Gxu = (self.G(xus) + self.G_bias).view(-1, 3, 3)
        else:
            # reshape (not view): the expanded bias is non-contiguous
            Gxu = self.G_bias.expand(states.shape[0], 9).reshape(-1, 3, 3)
        lcp_slack = fxu + torch.bmm(Gxu, lambdas.unsqueeze(2)).squeeze(2)
        return lcp_slack
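
# A minimal usage sketch (not from the original source): evaluates the LCP
# slack on random (x, u, lambda) rows; assumes the LcpStructuredNet class
# above plus `import torch`.
net = LcpStructuredNet(warm_start=True, include_G_weights=True)
states = torch.randn(4, 5)   # columns 0:2 are (x, u), columns 2:5 are lambdas
slack = net(states)          # (4, 3)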
class GumbelMF(nn.Module):
    def __init__(self, n_users, n_items, n_dim, n_obs,
                 lub=1., lib=1., luv=1., liv=1., tau=0.8, loss=nn.MSELoss):
        super(GumbelMF, self).__init__()
        self.embed_user = BiasedEmbedding(n_users, n_dim, lb=lub, lv=luv)
        self.embed_item = BiasedEmbedding(n_items, n_dim, lb=lib, lv=liv)
        self.glob_bias = Parameter(torch.FloatTensor([0.01]))
        self.n_obs = n_obs
        self.lossf = loss()
        self.tau = tau

    def forward(self, u, i):
        u, i = u.squeeze(), i.squeeze()
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, lu = self.embed_user(u)
        bi, li = self.embed_item(i)
        if self.training:
            du, di = gumbel_softmax_correlated(lu, li)
            # du = gumbel_softmax(lu, self.tau)
            # di = gumbel_softmax(li, self.tau)
        else:
            du = F.softmax(lu, dim=1)
            di = F.softmax(li, dim=1)
        intx = hellinger(du, di)
        logodds = (bias + bi + bu + intx).squeeze()
        return logodds

    def loss(self, prediction, target):
        # average likelihood loss per example
        ex_llh = self.lossf(prediction, target)
        # regularization penalty summed over the whole model
        epoch_reg = (self.embed_user.prior() + self.embed_item.prior())
        # the penalty should be computed for a single example
        ex_reg = epoch_reg * 1.0 / self.n_obs
        return ex_llh + ex_reg
class MFPoincare(nn.Module):
    def __init__(self, n_users, n_items, n_dim, n_obs):
        super(MFPoincare, self).__init__()
        self.embed_user = BiasedEmbedding(n_users, n_dim)
        self.embed_item = BiasedEmbedding(n_items, n_dim)
        self.glob_bias = Parameter(torch.Tensor(1, 1))
        self.n_obs = n_obs

    def forward(self, u, i):
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, vu = self.embed_user(u)
        bi, vi = self.embed_item(i)
        dist = poincare_distance(vu, vi)
        logodds = bias + bi + bu + dist
        return logodds

    def loss(self, prediction, target):
        # Don't know how to regularize Poincare space yet!
        llh = F.binary_cross_entropy_with_logits(prediction, target)
        return llh
class P_AnalysisDict(nn.Module):
    def __init__(self, col_synthesis_dict, col_sample, row_sample, num_sample, device):
        super(P_AnalysisDict, self).__init__()
        self.col_synthesis_dict = col_synthesis_dict
        self.num_views = col_sample
        self.num_sample = num_sample
        self.row_sample = row_sample
        self.device = device
        self.Analysis_Dict = Parameter(
            pt.Tensor(col_synthesis_dict, row_sample))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.Analysis_Dict.size(0))
        self.Analysis_Dict.data.uniform_(-stdv, stdv)

    def forward(self, X):
        # get the rebuilt value of the sparse code
        P = self.Analysis_Dict.expand(X.size(0), self.col_synthesis_dict,
                                      self.row_sample)
        out = pt.bmm(P, X)
        return out
class P_SynthesisDict(nn.Module):
    def __init__(self, col_SynthesisDict, col_sample, row_sample, num_sample, device):
        super(P_SynthesisDict, self).__init__()
        self.col_synthesis_dict = col_SynthesisDict
        self.num_views = col_sample
        self.num_sample = num_sample
        self.row_sample = row_sample
        self.device = device
        self.Synthesis_Dict = Parameter(
            pt.Tensor(row_sample, col_SynthesisDict))
        self.reset_parameters()

    def reset_parameters(self):
        # axis=0: compute the standard deviation over each column
        stdv = 1. / math.sqrt(self.Synthesis_Dict.size(0))
        self.Synthesis_Dict.data.uniform_(-stdv, stdv)

    def forward(self, Sparse_Code):
        # obtain the rebuilt value of sample_x
        B = self.Synthesis_Dict.expand(self.num_sample, self.row_sample,
                                       self.col_synthesis_dict)
        out = pt.bmm(B, Sparse_Code)
        return out
class MF(nn.Module):
    def __init__(self, n_users, n_items, n_dim, n_obs,
                 lub=1., lib=1., luv=1., liv=1., loss=nn.MSELoss):
        super(MF, self).__init__()
        self.embed_user = BiasedEmbedding(n_users, n_dim, lb=lub, lv=luv)
        self.embed_item = BiasedEmbedding(n_items, n_dim, lb=lib, lv=liv)
        self.glob_bias = Parameter(torch.FloatTensor([0.01]))
        self.n_obs = n_obs
        self.lossf = loss()

    def forward(self, u, i):
        u, i = u.squeeze(), i.squeeze()
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, vu = self.embed_user(u)
        bi, vi = self.embed_item(i)
        intx = (vu * vi).sum(dim=1)
        logodds = (bias + bi + bu + intx).squeeze()
        return logodds

    def loss(self, prediction, target):
        # average likelihood loss per example
        ex_llh = self.lossf(prediction, target)
        if self.training:
            # regularization penalty summed over the whole model
            epoch_reg = (self.embed_user.prior() + self.embed_item.prior())
            # the penalty should be computed for a single example
            ex_reg = epoch_reg * 1.0 / self.n_obs
        else:
            ex_reg = 0.0
        return ex_llh + ex_reg
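
# A minimal usage sketch (not from the original source): scores a batch of
# (user, item) pairs; assumes the MF and BiasedEmbedding classes above plus
# `import torch`.
model = MF(n_users=50, n_items=80, n_dim=8, n_obs=1000)
u = torch.randint(0, 50, (16,))
i = torch.randint(0, 80, (16,))
pred = model(u, i)                        # (16,) predicted log-odds/ratings
loss = model.loss(pred, torch.randn(16))  # MSE plus the scaled L2 prior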
class MFPoly2(nn.Module):
    def __init__(self, n_users, n_items, n_dim, n_obs,
                 lub=1., lib=1., luv=1., liv=1., loss=nn.MSELoss):
        super(MFPoly2, self).__init__()
        self.embed_user = BiasedEmbedding(n_users, n_dim, lb=lub, lv=luv)
        self.embed_item = BiasedEmbedding(n_items, n_dim, lb=lib, lv=liv)
        self.glob_bias = Parameter(torch.FloatTensor([0.01]))
        # maps from a scalar (dim=2) of frame and frame^2 to a scalar
        # log odds (dim=1), effectively fitting a quadratic polynomial
        # to the frame number
        self.poly = nn.Linear(2, 1)
        self.n_obs = n_obs
        self.lossf = loss()

    def forward(self, u, i, f):
        u, i = u.squeeze(), i.squeeze()
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, vu = self.embed_user(u)
        bi, vi = self.embed_item(i)
        intx = (vu * vi).sum(dim=1)
        frame = [f.unsqueeze(0), f.unsqueeze(0) ** 2.0]
        effect = self.poly(torch.t(torch.log(torch.cat(frame)))).squeeze()
        logodds = (bias + bi + bu + intx + effect).squeeze()
        return logodds

    def loss(self, prediction, target):
        # average likelihood loss per example
        ex_llh = self.lossf(prediction, target)
        if self.training:
            # regularization penalty summed over the whole model
            epoch_reg = (self.embed_user.prior() + self.embed_item.prior())
            # the penalty should be computed for a single example
            ex_reg = epoch_reg * 1.0 / self.n_obs
        else:
            ex_reg = 0.0
        return ex_llh + ex_reg
class Decoder(nn.Module):
    """decoder layer

    Receives:
    * first decoder hidden
    * first decoder input
    * each step encoder inputs
    * each step encoder outputs

    for conducting the rnn layers and pointer network attention, then gives
    each decoder step's pointer index and corresponding attention alphas.
    """

    def __init__(self, embed_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # lstm layer
        self.input2hidden = nn.Linear(embed_dim, hidden_dim * 4)
        self.hidden2hidden = nn.Linear(hidden_dim, hidden_dim * 4)
        self.hidden2out = nn.Linear(2 * hidden_dim, hidden_dim)
        # pointer attention layer
        self.attention = Attention(hidden_dim, hidden_dim)
        self.mask = Parameter(torch.ones(1), requires_grad=False)
        self.runner = Parameter(torch.zeros(1), requires_grad=False)  # records each step index

    def forward(self, enc_inputs, dec_inputs, hidden, context):
        """
        :type enc_inputs: torch.Tensor, each encoder step inputs
        :type dec_inputs: torch.Tensor, first decoder inputs fed
        :type hidden: torch.Tensor, first decoder hidden fed
        :type context: torch.Tensor, each encoder step outputs
        """
        batch_size = enc_inputs.shape[0]
        seq_len = enc_inputs.shape[1]
        mask = self.mask.expand(batch_size, seq_len)
        self.attention.init_inf(mask.size())
        # repeat (not expand) so the in-place index_fill_ below is legal
        runner = self.runner.repeat(batch_size, seq_len)
        for i in range(seq_len):
            runner.index_fill_(1, torch.tensor([i]), i)
        outputs = []
        pointers = []
        # ==============
        # decoder for each step
        # ==============
        for _ in range(seq_len):
            g_o, (c_t, h_t) = self.lstm_step(dec_inputs, hidden)
            h_t, alpha = self.attention_step(h_t, context, mask)
            hidden = (h_t, c_t)
            masked_alpha = alpha * mask  # [batch_size, seq_len]
            max_prob, max_index = masked_alpha.max(1)
            # update mask
            seen_pointer = (runner == max_index.unsqueeze(1).expand(
                -1, alpha.shape[1])).float()
            mask = mask * (1 - seen_pointer)
            # update dec_inputs from mask
            embedding_mask = seen_pointer.unsqueeze(2).expand(-1, -1, self.embed_dim)
            dec_inputs = enc_inputs[embedding_mask.bool()].view(batch_size, self.embed_dim)
            outputs.append(alpha.unsqueeze(0))
            pointers.append(max_index.unsqueeze(1))
        outputs = torch.cat(outputs, dim=0).permute(1, 0, 2)  # [batch, seq_len, seq_len]
        pointers = torch.cat(pointers, dim=1)  # [batch, seq_len]
        return (outputs, pointers), hidden

    def lstm_step(self, x, hidden):
        """single lstm cell forward"""
        h, c = hidden
        gates_cells = self.input2hidden(x) + self.hidden2hidden(h)  # feed h, not the (h, c) tuple
        g_i, g_f, g_c, g_o = gates_cells.chunk(4, 1)
        g_i = torch.sigmoid(g_i)
        g_f = torch.sigmoid(g_f)
        g_c = torch.tanh(g_c)
        g_o = torch.sigmoid(g_o)
        c_t = (g_f * c) + (g_i * g_c)
        h_t = g_o * torch.tanh(c_t)
        return g_o, (c_t, h_t)

    def attention_step(self, h_t, context, mask):
        """single pointer network attention forward"""
        alpha, attention_hidden_state = self.attention(h_t, context, torch.eq(mask, 0))
        attention_hidden_state = torch.tanh(
            self.hidden2out(torch.cat((attention_hidden_state, h_t), 1)))
        # acts as a fusion layer; the hidden state from the attention layer
        # could also be used directly
        return attention_hidden_state, alpha
class ViT(Module):
    """Vision transformer as per https://arxiv.org/abs/2010.11929."""

    @staticmethod
    def get_params():
        return {
            "image_size": cfg.TRAIN.IM_SIZE,
            "patch_size": cfg.VIT.PATCH_SIZE,
            "stem_type": cfg.VIT.STEM_TYPE,
            "c_stem_kernels": cfg.VIT.C_STEM_KERNELS,
            "c_stem_strides": cfg.VIT.C_STEM_STRIDES,
            "c_stem_dims": cfg.VIT.C_STEM_DIMS,
            "n_layers": cfg.VIT.NUM_LAYERS,
            "n_heads": cfg.VIT.NUM_HEADS,
            "hidden_d": cfg.VIT.HIDDEN_DIM,
            "mlp_d": cfg.VIT.MLP_DIM,
            "cls_type": cfg.VIT.CLASSIFIER_TYPE,
            "num_classes": cfg.MODEL.NUM_CLASSES,
        }

    @staticmethod
    def check_params(params):
        p = params
        err_str = "Input shape indivisible by patch size"
        assert p["image_size"] % p["patch_size"] == 0, err_str
        assert p["stem_type"] in ["patchify", "conv"], "Unexpected stem type"
        assert p["cls_type"] in ["token", "pooled"], "Unexpected classifier mode"
        if p["stem_type"] == "conv":
            err_str = "Conv stem layers mismatch"
            assert len(p["c_stem_dims"]) == len(p["c_stem_strides"]), err_str
            assert len(p["c_stem_strides"]) == len(p["c_stem_kernels"]), err_str
            err_str = "Stem strides unequal to patch size"
            assert p["patch_size"] == np.prod(p["c_stem_strides"]), err_str
            err_str = "Stem output dim unequal to hidden dim"
            assert p["c_stem_dims"][-1] == p["hidden_d"], err_str

    def __init__(self, params=None):
        super(ViT, self).__init__()
        p = ViT.get_params() if not params else params
        ViT.check_params(p)
        if p["stem_type"] == "patchify":
            self.stem = ViTStemPatchify(3, p["hidden_d"], p["patch_size"])
        elif p["stem_type"] == "conv":
            ks, ws, ss = p["c_stem_kernels"], p["c_stem_dims"], p["c_stem_strides"]
            self.stem = ViTStemConv(3, ks, ws, ss)
        seq_len = (p["image_size"] // cfg.VIT.PATCH_SIZE) ** 2
        if p["cls_type"] == "token":
            self.class_token = Parameter(torch.zeros(1, 1, p["hidden_d"]))
            seq_len += 1
        else:
            self.class_token = None
        self.pos_embedding = Parameter(torch.zeros(seq_len, 1, p["hidden_d"]))
        self.encoder = ViTEncoder(p["n_layers"], p["hidden_d"], p["n_heads"], p["mlp_d"])
        self.head = ViTHead(p["hidden_d"], p["num_classes"])
        init_weights_vit(self)

    def forward(self, x):
        # (n, c, h, w) -> (n, hidden_d, n_h, n_w)
        x = self.stem(x)
        # (n, hidden_d, n_h, n_w) -> (n, hidden_d, (n_h * n_w))
        x = x.reshape(x.size(0), x.size(1), -1)
        # (n, hidden_d, (n_h * n_w)) -> ((n_h * n_w), n, hidden_d)
        x = x.permute(2, 0, 1)
        if self.class_token is not None:
            # Expand the class token to the full batch
            class_token = self.class_token.expand(-1, x.size(1), -1)
            x = torch.cat([class_token, x], dim=0)
        x = x + self.pos_embedding
        x = self.encoder(x)
        # `token` or `pooled` features for classification
        x = x[0, :, :] if self.class_token is not None else x.mean(dim=0)
        return self.head(x)

    @staticmethod
    def complexity(cx, params=None):
        """Computes model complexity. If you alter the model, make sure to update."""
        p = ViT.get_params() if not params else params
        ViT.check_params(p)
        if p["stem_type"] == "patchify":
            cx = ViTStemPatchify.complexity(cx, 3, p["hidden_d"], p["patch_size"])
        elif p["stem_type"] == "conv":
            ks, ws, ss = p["c_stem_kernels"], p["c_stem_dims"], p["c_stem_strides"]
            cx = ViTStemConv.complexity(cx, 3, ks, ws, ss)
        seq_len = (p["image_size"] // cfg.VIT.PATCH_SIZE) ** 2
        if p["cls_type"] == "token":
            seq_len += 1
            cx["params"] += p["hidden_d"]
        # Params of position embeddings
        cx["params"] += seq_len * p["hidden_d"]
        cx = ViTEncoder.complexity(cx, p["n_layers"], p["hidden_d"],
                                   p["n_heads"], p["mlp_d"], seq_len)
        cx = ViTHead.complexity(cx, p["hidden_d"], p["num_classes"])
        return cx
class dmn(Parameterized):
    # this is equivalent to pyro.contrib.gp.models.model.GPModel
    r"""
    Base class for Dynamic Networks using the Latent Space representation.

    Each node :math:`i` has four Gaussian Processes associated:

    1. Location in the latent social space for link propensity (probability of being connected)
    2. Location in the latent social space for edge weight (size of the connection)
    3. Sociability of the node
    4. Popularity of the node

    There are two ways to train the DMN model:

    + Using an MCMC algorithm
    + Using variational inference on the pair :meth:`model`, :meth:`guide`
    """

    def __init__(self, edgelist, H_dim=3, X=None, weighted=True, directed=True,
                 coord=False, socpop=True, jitter=1e-6, whiten=False):
        super(dmn, self).__init__()

        Y, [all_nodes, Y_time, all_layers] = pydmn.util.edgelist_to_tensor(edgelist=edgelist)
        self.all_layers = all_layers
        self.set_data(Y=Y, Y_time=Y_time, H_dim=H_dim, X=X)

        self.whiten = whiten
        self.jitter = jitter
        self.weighted = weighted
        self.directed = directed
        self.coord = coord
        self.socpop = socpop

        # Dimensions of objects change if the network is weighted and/or directed
        self.lw_dim = (2 if weighted else 1)
        self.sr_dim = (2 if directed else 1)

        self.kernel = Parameterized()  # kernel modules will be added here
        self.gp = Parameterized()  # GP modules will be added here

        if self.weighted:
            sigma_Y = torch.ones(self.K_net)
            self.sigma_Y = Parameter(sigma_Y)
            self.set_constraint("sigma_Y", constraints.positive)

        ### Latent locations ###
        if coord:
            # Kernels for GP of latent locations #
            # Assume that the kernel is the same for all agents,
            # i.e. the dynamics in the latent space is the same #
            self.kernel.coord = Parameterized()
            self.kernel.coord.link = gp.kernels.RBF(input_dim=1)
            if self.weighted:
                self.kernel.coord.weight = gp.kernels.RBF(input_dim=1)

            ## Dynamic latent locations ##
            self.gp.coord = Parameterized()

            # GP location (without mean fun) #
            # dimensions: (V_net, H_dim, sr_dim, lw_dim, T_net)
            #           = (1 per agent, 1 per lat dim, send & rec, link & weight, Time)
            gp_coord_loc = torch.randn(self.V_net, self.H_dim, self.sr_dim,
                                       self.lw_dim, self.T_net)
            self.gp.coord.loc = Parameter(gp_coord_loc)

            # Mean function: constant #
            # dimensions: (V_net, H_dim, sr_dim, lw_dim)
            #           = (1 per agent, 1 per lat dim, send & rec, link & weight)
            gp_coord_loc_mean = torch.randn(self.V_net, self.H_dim,
                                            self.sr_dim, self.lw_dim)
            self.gp.coord.loc_mean = Parameter(gp_coord_loc_mean)

            # Lower Cholesky of the covariance matrix of latent coordinates #
            # dimensions: (V_net, H_dim, sr_dim, lw_dim, T_net, T_net)
            #           = (1 per agent, 1 per lat dim, send & rec, link & weight, Time, Time)
            gp_coord_cov_tril_unconst = torch.diag_embed(
                torch.ones(self.V_net, self.H_dim, self.sr_dim,
                           self.lw_dim, self.T_net))
            self.gp.coord.cov_tril_unconst = Parameter(gp_coord_cov_tril_unconst)
            self.gp.coord.cov_tril = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack([
                            transform_to(constraints.lower_cholesky)(
                                self.gp.coord.cov_tril_unconst[v, h, sr_i, lw_i, :, :])
                            for lw_i in range(self.lw_dim)
                        ]) for sr_i in range(self.sr_dim)
                    ]) for h in range(self.H_dim)
                ]) for v in range(self.V_net)
            ])
        else:
            self.kernel.coord = None
            self.gp.coord = None

        ### Sociability and Popularity ###
        if socpop:
            # Kernels for GP of Sociability and Popularity params #
            # Assume that the kernel is the same for sociability and popularity #
            self.kernel.socpop = Parameterized()
            self.kernel.socpop.link = gp.kernels.RBF(input_dim=1)
            if self.weighted:
                self.kernel.socpop.weight = gp.kernels.RBF(input_dim=1)

            ## Dynamic Sociability and Popularity ##
            self.gp.socpop = Parameterized()

            # dimensions: (V_net, 2, lw_dim, T_net)
            #           = (1 per agent, soc & pop, link & weight, Time)
            gp_socpop_loc = torch.randn(self.V_net, 2, self.lw_dim, self.T_net)
            self.gp.socpop.loc = Parameter(gp_socpop_loc)

            # Mean function: constant #
            gp_socpop_loc_mean = torch.randn(self.V_net, 2, self.lw_dim)
            self.gp.socpop.loc_mean = Parameter(gp_socpop_loc_mean)

            # gp.socpop.cov_tril: lower Cholesky of the covariance matrix of the
            # sociability/popularity processes
            # dimensions: (V_net, 2, lw_dim, T_net, T_net)
            #           = (1 per agent, soc & pop, link & weight, Time, Time)
            gp_socpop_cov_tril_unconst = torch.diag_embed(
                torch.ones(self.V_net, 2, self.lw_dim, self.T_net))
            self.gp.socpop.cov_tril_unconst = Parameter(gp_socpop_cov_tril_unconst)
            self.gp.socpop.cov_tril = torch.stack([
                torch.stack([
                    torch.stack([
                        transform_to(constraints.lower_cholesky)(
                            self.gp.socpop.cov_tril_unconst[v, sp_i, lw_i, :, :])
                        for lw_i in range(self.lw_dim)
                    ]) for sp_i in range(2)
                ]) for v in range(self.V_net)
            ])
        else:
            self.kernel.socpop = None
            self.gp.socpop = None

    # @autoname.scope(prefix="DMN")  # generates error
    def model(self):
        self.set_mode("model")

        # Sample the coordinates #
        # dimensions: (V_net, H_dim, sr_dim, lw_dim, T_net)
        #           = (1 per agent, 1 per lat dim, send & rec, link & weight, Time)
        if self.coord:
            # Calculates the lower Cholesky factor for the link/weight kernels
            Kff = [
                eval('self.kernel.coord.' + ['link', 'weight'][lw_i])(
                    self.Y_time).contiguous()
                for lw_i in range(self.lw_dim)
            ]
            for lw_i in range(self.lw_dim):
                Kff[lw_i].view(-1)[::self.T_net + 1] += self.jitter  # add jitter to the diagonal
            Lff_coord = [Kff[lw_i].cholesky() for lw_i in range(self.lw_dim)]

            # Gaussian process for the latent coordinates #
            gp_coord = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack([
                            pydmn.util.GP_sample(
                                name=f'f_coord_v{v}_h{h}_sr{sr_i}_lw{lw_i}',
                                X=self.Y_time,
                                f_loc=self.gp.coord.loc[v, h, sr_i, lw_i, :],
                                f_loc_mean=self.gp.coord.loc_mean[v, h, sr_i, lw_i],
                                f_scale_tril=self.gp.coord.cov_tril[v, h, sr_i, lw_i, :, :],
                                Lff=Lff_coord[lw_i],
                                whiten=self.whiten)
                            for lw_i in range(self.lw_dim)
                        ]) for sr_i in range(self.sr_dim)
                    ]) for h in range(self.H_dim)
                ]) for v in range(self.V_net)
            ])

        # Sample the Sociability and Popularity #
        # dimensions: (V_net, 2, lw_dim, T_net)
        #           = (1 per agent, soc & pop, link & weight, Time)
        if self.socpop:
            # Calculates the lower Cholesky factor for the link/weight kernels
            Kff = [
                eval('self.kernel.socpop.' + ['link', 'weight'][lw_i])(
                    self.Y_time).contiguous()
                for lw_i in range(self.lw_dim)
            ]
            for lw_i in range(self.lw_dim):
                Kff[lw_i].view(-1)[::self.T_net + 1] += self.jitter  # add jitter to the diagonal
            Lff_socpop = [Kff[lw_i].cholesky() for lw_i in range(self.lw_dim)]

            # Gaussian process for the sociability and popularity #
            gp_socpop = torch.stack([
                torch.stack([
                    torch.stack([
                        pydmn.util.GP_sample(
                            name=f'f_socpop_v{v}_{["soc", "pop"][sp_i]}_lw{lw_i}',
                            X=self.Y_time,
                            f_loc=self.gp.socpop.loc[v, sp_i, lw_i, :],
                            f_loc_mean=self.gp.socpop.loc_mean[v, sp_i, lw_i],
                            f_scale_tril=self.gp.socpop.cov_tril[v, sp_i, lw_i, :, :],
                            Lff=Lff_socpop[lw_i],
                            whiten=self.whiten)
                        for lw_i in range(self.lw_dim)
                    ]) for sp_i in range(2)
                ]) for v in range(self.V_net)
            ])

        ### Calculate Linear Predictors ###
        Y_linpred = torch.zeros(self.V_net, self.V_net, self.T_net, self.lw_dim)
        # identifies diagonal elements, which will not be considered when
        # computing the likelihood
        Y_diag = torch.diag_embed(torch.ones(self.T_net, self.V_net)).transpose(0, 2).flatten()

        if self.coord:
            # gp_coord[v, h, sr_i, lw_i, t]
            send = gp_coord[:, :, 0, :, :].expand(
                self.V_net, self.V_net, self.H_dim, self.lw_dim,
                self.T_net).transpose(0, 1)
            if self.directed:
                receive = gp_coord[:, :, 1, :, :].expand(
                    self.V_net, self.V_net, self.H_dim, self.lw_dim, self.T_net)
            else:
                receive = gp_coord[:, :, 0, :, :].expand(
                    self.V_net, self.V_net, self.H_dim, self.lw_dim, self.T_net)
            Y_linpred += ((send - receive) ** 2).sum(dim=2).rsqrt().transpose(2, 3)

        if self.socpop:
            soc = gp_socpop[:, 0, :, :].expand(
                self.V_net, self.V_net, self.lw_dim, self.T_net).transpose(0, 1)
            pop = gp_socpop[:, 1, :, :].expand(
                self.V_net, self.V_net, self.lw_dim, self.T_net)
            Y_linpred += (soc + pop).transpose(2, 3)

        ### Sampling 0-1 links ###
        Y_probs = torch.sigmoid(Y_linpred[:, :, :, 0].flatten())[Y_diag == 0]
        Y_dist_link = dist.Bernoulli(Y_probs)
        Y_dist_link = Y_dist_link.expand_by(
            Y_probs.shape[:-Y_probs.dim()]).to_event(Y_probs.dim())
        Y_link_sample = pyro.sample("Y_link", Y_dist_link,
                                    obs=(self.Y_link.flatten())[Y_diag == 0])

        ### Sampling weights ###
        # sample the weights before returning, so this block is reachable
        if self.weighted:
            Y_weighted_id = (Y_diag == 0) & (self.Y.flatten() != 0)
            Y_loc = (Y_linpred[:, :, :, 1].flatten())[Y_weighted_id]
            Y_scale = self.sigma_Y.expand(self.V_net, self.V_net,
                                          self.T_net).flatten()[Y_weighted_id]
            Y_dist = dist.Normal(Y_loc, Y_scale)
            Y_dist = Y_dist.expand_by(Y_loc.shape[:-Y_loc.dim()]).to_event(Y_loc.dim())
            pyro.sample("Y", Y_dist, obs=(self.Y.flatten())[Y_weighted_id])

        return Y_link_sample

    # @autoname.scope(prefix="DMN")  # generates error
    def guide(self):
        self.set_mode("guide")

        if self.coord:
            for v in range(self.V_net):
                for h in range(self.H_dim):
                    for sr_i in range(self.sr_dim):
                        for lw_i in range(self.lw_dim):
                            pyro.sample(
                                f'f_coord_v{v}_h{h}_sr{sr_i}_lw{lw_i}',
                                dist.MultivariateNormal(
                                    self.gp.coord.loc[v, h, sr_i, lw_i, :],
                                    scale_tril=self.gp.coord.cov_tril[v, h, sr_i, lw_i, :, :]
                                ).to_event(self.gp.coord.loc[v, h, sr_i, lw_i, :].dim() - 1))

        if self.socpop:
            for v in range(self.V_net):
                for sp_i in range(2):
                    for lw_i in range(self.lw_dim):
                        pyro.sample(
                            f'f_socpop_v{v}_{["soc", "pop"][sp_i]}_lw{lw_i}',
                            dist.MultivariateNormal(
                                self.gp.socpop.loc[v, sp_i, lw_i, :],
                                scale_tril=self.gp.socpop.cov_tril[v, sp_i, lw_i, :, :]
                            ).to_event(self.gp.socpop.loc[v, sp_i, lw_i, :].dim() - 1))

    def forward(self, Y_time_new, num_particles=30):
        r"""
        Samples the posterior linear predictor of the network at the new time
        points ``Y_time_new``.
        """
        self.set_mode("guide")
        # Y_time_new = torch.arange(0, 30, 0.25)

        ## Generate coordinates for new times ##
        if self.coord:
            Lff_coord = []
            Kfs_coord = []
            Kss_coord = []
            for lw_i in range(self.lw_dim):
                Kff = eval('self.kernel.coord.' + ['link', 'weight'][lw_i])(
                    self.Y_time).contiguous()
                Kff.view(-1)[::self.T_net + 1] += self.jitter  # add jitter to the diagonal
                Lff_coord.append(Kff.cholesky())
                Kfs_coord.append(
                    eval('self.kernel.coord.' + ['link', 'weight'][lw_i])(
                        self.Y_time, Y_time_new))
                Kss_coord.append(
                    eval('self.kernel.coord.' + ['link', 'weight'][lw_i])(
                        Y_time_new).contiguous())
                Kss_coord[lw_i].view(
                    -1)[::Kss_coord[lw_i].shape[0] + 1] += self.jitter  # add jitter to the diagonal

            gp_coord_loc_and_cov_new = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack([
                            torch.stack(
                                pydmn.util.conditional(
                                    Xnew=Y_time_new,
                                    X=self.Y_time,
                                    kernel=None,
                                    f_loc=self.gp.coord.loc[v, h, sr_i, lw_i, :],
                                    f_scale_tril=self.gp.coord.cov_tril[v, h, sr_i, lw_i, :, :],
                                    Lff=Lff_coord[lw_i],
                                    full_cov=True,
                                    whiten=self.whiten,
                                    jitter=self.jitter,
                                    Kfs=Kfs_coord[lw_i],
                                    Kss=Kss_coord[lw_i]))
                            for lw_i in range(self.lw_dim)
                        ]) for sr_i in range(self.sr_dim)
                    ]) for h in range(self.H_dim)
                ]) for v in range(self.V_net)
            ])
            gp_coord_loc_new = gp_coord_loc_and_cov_new[:, :, :, :, 0, 0, :]
            gp_coord_cov_new = gp_coord_loc_and_cov_new[:, :, :, :, 1, :, :]

            gp_coord_new_sample = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack([
                            dist.MultivariateNormal(
                                gp_coord_loc_new[v, h, sr_i, lw_i, :] +
                                self.gp.coord.loc_mean[v, h, sr_i, lw_i],
                                gp_coord_cov_new[v, h, sr_i, lw_i, :, :]
                            ).expand([num_particles]).sample().transpose(0, 1)
                            for lw_i in range(self.lw_dim)
                        ]) for sr_i in range(self.sr_dim)
                    ]) for h in range(self.H_dim)
                ]) for v in range(self.V_net)
            ])

        ## Generate sociability and popularity for new times ##
        if self.socpop:
            Lff_socpop = []
            Kfs_socpop = []
            Kss_socpop = []
            for lw_i in range(self.lw_dim):
                Kff = eval('self.kernel.socpop.' + ['link', 'weight'][lw_i])(
                    self.Y_time).contiguous()
                Kff.view(-1)[::self.T_net + 1] += self.jitter  # add jitter to the diagonal
                Lff_socpop.append(Kff.cholesky())
                Kfs_socpop.append(
                    eval('self.kernel.socpop.' + ['link', 'weight'][lw_i])(
                        self.Y_time, Y_time_new))
                Kss_socpop.append(
                    eval('self.kernel.socpop.' + ['link', 'weight'][lw_i])(
                        Y_time_new).contiguous())
                Kss_socpop[lw_i].view(
                    -1)[::Kss_socpop[lw_i].shape[0] + 1] += self.jitter  # add jitter to the diagonal

            gp_socpop_loc_and_cov_new = torch.stack([
                torch.stack([
                    torch.stack([
                        torch.stack(
                            pydmn.util.conditional(
                                Xnew=Y_time_new,
                                X=self.Y_time,
                                kernel=None,
                                f_loc=self.gp.socpop.loc[v, sp_i, lw_i, :],
                                f_scale_tril=self.gp.socpop.cov_tril[v, sp_i, lw_i, :, :],
                                Lff=Lff_socpop[lw_i],
                                full_cov=True,
                                whiten=self.whiten,
                                jitter=self.jitter,
                                Kfs=Kfs_socpop[lw_i],
                                Kss=Kss_socpop[lw_i]))
                        for lw_i in range(self.lw_dim)
                    ]) for sp_i in range(2)
                ]) for v in range(self.V_net)
            ])
            gp_socpop_loc_new = gp_socpop_loc_and_cov_new[:, :, :, 0, 0, :]
            gp_socpop_cov_new = gp_socpop_loc_and_cov_new[:, :, :, 1, :, :]

            gp_socpop_new_sample = torch.stack([
                torch.stack([
                    torch.stack([
                        dist.MultivariateNormal(
                            gp_socpop_loc_new[v, sp_i, lw_i, :] +
                            self.gp.socpop.loc_mean[v, sp_i, lw_i],
                            gp_socpop_cov_new[v, sp_i, lw_i, :, :]
                        ).expand([num_particles]).sample().transpose(0, 1)
                        for lw_i in range(self.lw_dim)
                    ]) for sp_i in range(2)
                ]) for v in range(self.V_net)
            ])

        ### Calculate Linear Predictors ###
        Y_linpred_new_sample = torch.zeros(self.V_net, self.V_net,
                                           Y_time_new.shape[0], self.lw_dim,
                                           num_particles)

        if self.coord:
            # gp_coord_new_sample[v, h, sr_i, lw_i, t, particle]
            send = gp_coord_new_sample[:, :, 0, :, :, :].expand(
                self.V_net, self.V_net, self.H_dim, self.lw_dim,
                Y_time_new.shape[0], num_particles).transpose(0, 1)
            if self.directed:
                receive = gp_coord_new_sample[:, :, 1, :, :, :].expand(
                    self.V_net, self.V_net, self.H_dim, self.lw_dim,
                    Y_time_new.shape[0], num_particles)
            else:
                receive = gp_coord_new_sample[:, :, 0, :, :, :].expand(
                    self.V_net, self.V_net, self.H_dim, self.lw_dim,
                    Y_time_new.shape[0], num_particles)
            Y_linpred_new_sample += ((send - receive) ** 2).sum(dim=2).rsqrt().transpose(2, 3)

        if self.socpop:
            soc = gp_socpop_new_sample[:, 0, :, :, :].expand(
                self.V_net, self.V_net, self.lw_dim, Y_time_new.shape[0],
                num_particles).transpose(0, 1)
            pop = gp_socpop_new_sample[:, 1, :, :, :].expand(
                self.V_net, self.V_net, self.lw_dim, Y_time_new.shape[0],
                num_particles)
            Y_linpred_new_sample += (soc + pop).transpose(2, 3)

        # Y_linpred_new_sample.mean(dim=Y_linpred_new_sample.dim() - 1)
        return Y_linpred_new_sample

    def set_data(self, Y, Y_time, H_dim=3, X=None):
        """Sets the data for dmn models."""
        assert (len(Y.shape) >= 3) and (len(Y.shape) <= 4)
        # square adjacency matrices
        assert Y.shape[0] == Y.shape[1]
        self.Y = Y
        self.V_net, self.T_net = self.Y.shape[0], self.Y.shape[2]
        self.Y_link = torch.where(Y != 0, torch.ones_like(Y), torch.zeros_like(Y))
        self.K_net = Y.shape[3] if len(Y.shape) == 4 else 1

        assert self.T_net == Y_time.shape[0]
        self.Y_time = Y_time

        assert int(H_dim) >= 1
        self.H_dim = int(H_dim)

        self.X = X