class Linearlr(nn.Module):
    """Linear layer whose weight is the product of two low-rank factors.

    The effective weight is ``weightB @ weightA`` with ``weightA`` of shape
    ``(rank, in_features)`` and ``weightB`` of shape ``(out_features, rank)``,
    so the layer maps ``in_features -> out_features`` through a rank-``rank``
    bottleneck.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        rank: inner dimension of the factorisation; must not exceed
            ``min(in_features, out_features)`` (checked with ``assert``).
        bias: if ``True``, a learnable additive bias of size ``out_features``.
    """

    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        print("rank {}, in_features {}, out_features {}".format(
            rank, in_features, out_features))
        assert rank <= min(in_features, out_features)
        self.rank = rank
        # Registration order (weightA, weightB, bias) fixes parameters() order.
        self.weightA = Parameter(torch.Tensor(rank, in_features))
        self.weightB = Parameter(torch.Tensor(out_features, rank))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw all parameters from symmetric uniform distributions.

        Each factor uses the bound ``1/sqrt(fan_in)`` of its own second
        dimension; the bias shares ``weightA``'s bound.
        """
        bound_a = 1. / math.sqrt(self.weightA.size(1))
        bound_b = 1. / math.sqrt(self.weightB.size(1))
        # Draw order (weightA, bias, weightB) is kept so seeded runs match.
        self.weightA.data.uniform_(-bound_a, bound_a)
        if self.bias is not None:
            self.bias.data.uniform_(-bound_a, bound_a)
        self.weightB.data.uniform_(-bound_b, bound_b)

    def forward(self, input):
        """Apply ``input @ (weightB @ weightA).T + bias``."""
        return F.linear(input, self.weightB.matmul(self.weightA), self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)
class WeightedAttention(nn.Module):
    """
    Bilinear attention layer between premises and hypotheses encoded by an RNN.

    A learned square matrix ``w`` scores every (premise element, hypothesis
    element) pair as ``p_i^T w h_j``. The masked softmax of these scores is
    then used in a weighted sum of the hypothesis vectors for each element of
    the premises, and conversely of the premise vectors for each element of
    the hypotheses.
    """

    def __init__(self, embedding_dim):
        super(WeightedAttention, self).__init__()
        self.w = Parameter(torch.Tensor(embedding_dim, embedding_dim))
        # ``xavier_normal`` (no trailing underscore) is the deprecated alias
        # that warns on every call; ``xavier_normal_`` is the supported
        # in-place initialiser and draws the same values.
        torch.nn.init.xavier_normal_(self.w)

    def forward(self, premise_batch, premise_mask, hypothesis_batch,
                hypothesis_mask):
        """
        Args:
            premise_batch: A batch of sequences of vectors representing the
                premises in some NLI task. The batch is assumed to have the
                size (batch, sequences, vector_dim).
            premise_mask: A mask for the sequences in the premise batch, to
                ignore padding data in the sequences during the computation
                of the attention.
            hypothesis_batch: A batch of sequences of vectors representing
                the hypotheses in some NLI task. The batch is assumed to have
                the size (batch, sequences, vector_dim).
            hypothesis_mask: A mask for the sequences in the hypotheses
                batch, to ignore padding data in the sequences during the
                computation of the attention.

        Returns:
            attended_premises: The sequences of attention vectors for the
                premises in the input batch.
            attended_hypotheses: The sequences of attention vectors for the
                hypotheses in the input batch.
            similarity_matrix: The raw (pre-softmax) bilinear scores; with the
                shapes assumed above it is
                (batch, premise_length, hypothesis_length).
        """
        # Bilinear score premise_i^T . w . hypothesis_j for every pair of
        # elements in each sequence of the batch.
        similarity_matrix = premise_batch.matmul(
            self.w.matmul(hypothesis_batch.transpose(2, 1).contiguous()))

        # Softmax attention weights, masking padded positions.
        prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
        hyp_prem_attn = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)

        # Weighted sums of the hypotheses for the premises attention,
        # and vice-versa for the attention of the hypotheses.
        attended_premises = weighted_sum(hypothesis_batch, prem_hyp_attn,
                                         premise_mask)
        attended_hypotheses = weighted_sum(premise_batch, hyp_prem_attn,
                                           hypothesis_mask)

        return attended_premises, attended_hypotheses, similarity_matrix
class Factorize(nn.Module):
    """Low-rank matrix factorisation with a global bias and a per-row bias.

    ``forward()`` returns ``A @ B + global_bias + bias_miast[:, None]``, a
    matrix of shape ``(n_rows, n_cols)``.

    Args:
        factors: inner (latent) dimension of the factorisation.
        n_rows: number of rows. Defaults to the module-level ``l_miast``
            (presumably the number of cities, from Polish "miast" -- confirm).
        n_cols: number of columns. Defaults to the module-level
            ``l_miesiecy`` (presumably the number of months, from Polish
            "miesiecy" -- confirm).
    """

    def __init__(self, factors, n_rows=None, n_cols=None):
        super(Factorize, self).__init__()
        # Globals are resolved only when no explicit size is given, so the
        # original single-argument call keeps working unchanged.
        if n_rows is None:
            n_rows = l_miast
        if n_cols is None:
            n_cols = l_miesiecy
        # Creation order is kept so seeded initialisation is unchanged.
        self.A = Parameter(torch.randn(n_rows, factors))
        self.B = Parameter(torch.randn(factors, n_cols))
        self.global_bias = Parameter(torch.randn(1))
        self.bias_miast = Parameter(torch.randn(n_rows))

    def forward(self):
        """Return the reconstructed ``(n_rows, n_cols)`` matrix."""
        # unsqueeze(1) broadcasts the per-row bias across every column. This
        # replaces a per-column Python loop of in-place writes into a
        # transposed view; the additions happen in the same order.
        return (self.A.matmul(self.B) + self.global_bias
                + self.bias_miast.unsqueeze(1))
class BilinearMLPAbstractPredictor(MLPAbstractPredictor):
    """
    Similar to the MLP Abstract Predictor but applies a bilinear transform
    instead of addition.
    """

    def __init__(self, data, config, predictor_layers, uses_raw_response):
        super(BilinearMLPAbstractPredictor, self).__init__(
            data, config, predictor_layers, uses_raw_response)

        # Number of bilinear transformations == the dimension of the layer at
        # which the merge is performed: one (merge_dim x merge_dim) matrix per
        # output unit, stacked into a (merge_dim, merge_dim, merge_dim) tensor.
        # Initialize weights close to identity (identity + 1e-2 noise).
        self.bilinear_weights = Parameter(
            1 / 100 * torch.randn(
                (self.merge_dim, self.merge_dim, self.merge_dim))
            + torch.cat([torch.eye(self.merge_dim)[None, :, :]]
                        * self.merge_dim, dim=0))
        self.bilinear_offsets = Parameter(
            1 / 100 * torch.randn((self.merge_dim)))

    def single_forward_pass(self, h_drug_1, h_drug_2, cell_lines):
        """Merge two drug embeddings bilinearly and run the after-merge MLP.

        Returns a 3-tuple: the combined prediction, then
        ``transform_single_drug`` applied to each drug separately.
        """
        # Apply before merge MLP
        h_1 = self.before_merge_mlp([h_drug_1, cell_lines])[0]
        h_2 = self.before_merge_mlp([h_drug_2, cell_lines])[0]

        # Assumes h_* is (batch, merge_dim) -- TODO confirm against the base
        # class. (M, M, M) @ (M, batch) -> (M, M, batch); reversing all dims
        # gives (batch, M, M). permute(2, 1, 0) is exactly what the deprecated
        # ``.T`` on a 3-D tensor did.
        h_1 = self.bilinear_weights.matmul(h_1.T).permute(2, 1, 0)
        h_2 = self.bilinear_weights.matmul(h_2.T).permute(2, 1, 0)

        # "Transpose" h_1
        # NOTE(review): the original comment described
        # <W_k.h_1, W_k.h_2> = h_1.T . W_k.T.W_k . h_2 for each k, which would
        # be (h_1 * h_2).sum(1) WITHOUT this permute. With the permute, output
        # unit y is sum_x (W[x].h_1)[y] * (W[y].h_2)[x], which is a different
        # bilinear form. It is kept unchanged so trained checkpoints behave the
        # same; confirm which formulation is intended.
        h_1 = h_1.permute(0, 2, 1)

        # Multiplication
        h_1_scal_h_2 = (h_1 * h_2).sum(1)

        # Add offset
        h_1_scal_h_2 += self.bilinear_offsets

        comb = self.after_merge_mlp([h_1_scal_h_2, cell_lines])[0]

        return (
            comb,
            self.transform_single_drug(h_drug_1, cell_lines),
            self.transform_single_drug(h_drug_2, cell_lines),
        )
class Linearsp_v2(nn.Module):
    """Linear layer whose weight is a low-rank product plus a dense term.

    The effective weight is

    * ``weightB @ (weightA + I) + weightC`` when ``rank == in_features``;
    * ``(weightB + I) @ weightA + weightC`` otherwise,

    where ``I`` is the ``(rank, rank)`` identity kept in the ``eye_const``
    buffer. All three weight tensors start at zero; only the bias (if any) is
    randomised by ``reset_parameters``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        rank: inner dimension of the factors; must not exceed
            ``min(in_features, out_features)`` (checked with ``assert``).
        bias: if ``True``, a learnable additive bias of size ``out_features``.
    """

    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        assert rank <= min(in_features, out_features)
        self.rank = rank

        self.weightA = Parameter(torch.zeros(rank, in_features))
        self.weightB = Parameter(torch.zeros(out_features, rank))
        self.weightC = Parameter(torch.zeros(out_features, in_features))

        # ``eye`` is a plain attribute; ``eye_const`` is the registered buffer
        # (moved by .to() and saved in state_dict) that forward() reads.
        self.eye = torch.eye(rank)
        self.register_buffer('eye_const', self.eye)

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Randomise the bias uniformly in ``+-1/sqrt(in_features)``.

        weightA, weightB and weightC keep their zero initialisation.
        """
        if self.bias is None:
            return
        bound = 1. / math.sqrt(self.weightA.size(1))
        self.bias.data.uniform_(-bound, bound)

    def forward(self, input):
        """Apply ``input @ weight.T + bias`` with the weight described above."""
        factor_a, factor_b = self.weightA, self.weightB
        dense, eye = self.weightC, self.eye_const
        if self.rank != self.in_features:
            # NOTE(review): (rank, rank) + (out_features, rank) only
            # broadcasts when rank == out_features or rank == 1, so configs
            # with 1 < rank < min(in_features, out_features) fail here.
            weight = (factor_b + eye).matmul(factor_a) + dense
        else:
            weight = factor_b.matmul(factor_a + eye) + dense
        return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)
class BaseRNNCell(nn.Module):
    """Single-step recurrent cell.

    ``forward(input, hx)`` returns
    ``activation(input @ weight_ih.T + bias_ih + hx @ (weight_ih @ weight_hh1))``.

    Args:
        input_size: size of each input sample.
        hidden_size: size of the hidden state.
        bias: if ``True``, adds a learnable ``bias_ih`` drawn from N(0, 1).
        nonlinearity: one of ``"tanh"``, ``"relu"``, ``"sigmoid"``, ``"log"``,
            ``"sin"``; anything else raises ``RuntimeError``.
        hidden_min_abs / hidden_max_abs: magnitude bounds that
            ``check_bounds`` applies to ``weight_hh`` (falsy = no bound).
        hidden_init, recurrent_init: stored on the instance.
        gradient_clip: accepted for interface compatibility.

    NOTE(review): hidden_init, recurrent_init and gradient_clip are not used
    by any method visible here.
    """

    def __init__(self, input_size, hidden_size, bias=False,
                 nonlinearity="tanh", hidden_min_abs=0, hidden_max_abs=None,
                 hidden_init=None, recurrent_init=None, gradient_clip=5):
        super(BaseRNNCell, self).__init__()
        self.hidden_max_abs = hidden_max_abs
        self.hidden_min_abs = hidden_min_abs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.nonlinearity = nonlinearity
        self.hidden_init = hidden_init
        self.recurrent_init = recurrent_init

        # torch.tanh / torch.sigmoid are the supported spellings of the
        # deprecated F.tanh / F.sigmoid and compute the same values.
        activations = {
            "tanh": torch.tanh,
            "relu": F.relu,
            "sigmoid": torch.sigmoid,
            "log": torch.log,
            "sin": torch.sin,
        }
        activation = activations.get(self.nonlinearity)
        if activation is None:
            raise RuntimeError("Unknown nonlinearity: {}".format(
                self.nonlinearity))
        self.activation = activation

        # Creation order is kept so parameters() order and seeded draws match.
        self.weight_ih = Parameter(torch.eye(hidden_size, input_size))
        # NOTE(review): the second dimension of weight_hh is hard-coded to 20,
        # and forward() does not read weight_hh (only check_bounds does).
        self.weight_hh = Parameter(torch.Tensor(hidden_size, 20).uniform_())
        self.weight_hh1 = Parameter(torch.eye(input_size, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.randn(hidden_size))
        else:
            self.register_parameter('bias_ih', None)
        # reset_parameters() is not called here, so the identity / uniform
        # initialisations above are the ones in effect.

    def reset_parameters(self):
        """Re-draw every parameter uniformly in ``+-1/sqrt(hidden_size)``."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def check_bounds(self):
        """Clamp the magnitudes of ``weight_hh`` in place, keeping signs.

        Magnitudes are raised to at least ``hidden_min_abs`` (entries that
        are exactly zero stay zero because ``sign(0) == 0``), then clipped to
        ``[-hidden_max_abs, hidden_max_abs]``.
        """
        if self.hidden_min_abs:
            abs_kernel = torch.abs(self.weight_hh.data).clamp_(
                min=self.hidden_min_abs)
            # torch.mul(a, b): the previous ``self.weight_hh.mul(a, b)``
            # passed two operands to Tensor.mul and raised TypeError.
            self.weight_hh.data = torch.mul(torch.sign(self.weight_hh.data),
                                            abs_kernel)
        if self.hidden_max_abs:
            self.weight_hh.data = self.weight_hh.data.clamp(
                max=self.hidden_max_abs, min=-self.hidden_max_abs)

    def forward(self, input, hx):
        """Compute the next hidden state from ``input`` and previous ``hx``."""
        # The recurrent matrix is weight_ih @ weight_hh1 (hidden x hidden).
        recurrent = torch.matmul(hx, self.weight_ih.matmul(self.weight_hh1))
        return self.activation(
            F.linear(input, self.weight_ih, self.bias_ih) + recurrent)

    def talor(self, x):
        """Third-order Taylor expansion of ``log(x)`` around ``x = 1``."""
        return (x - 1) - (x - 1) * (x - 1) / 2 + (x - 1) * (x - 1) * (x - 1) / 3
class DaleConstrainedIntegrator(Module):
    """ReLU recurrent integrator with sign-constrained (Dale-like) coupling.

    Every ``(key, value)`` in ``args_dict`` becomes an attribute. Keys read by
    this class: ``saturations`` (must be ``[0, 1e8]``), ``n`` (number of
    neurons), ``n_channels``, ``init_vectors_type`` (``'random'`` or
    ``'orthonormal'``), ``init_vectors_overlap`` and ``init_vectors_scales``
    (used when ``n_channels == 1``), ``inhib_proportion``, ``init_radius``
    (must be 0), ``device_name`` and ``save_folder`` (created on init).

    The effective coupling is ``W @ diag(synapse_signs)``: column ``j`` (the
    outgoing weights of neuron ``j``; ``W_ij`` is the weight from ``j`` to
    ``i``) is multiplied by +1 for the first ``n_excit`` neurons and by -1 for
    the last ``n_inhib``. ``integrate`` requires ``W >= 0`` entrywise.
    """

    def __init__(self, args_dict):
        super(DaleConstrainedIntegrator, self).__init__()
        self.is_W_parametrized = True
        self.is_dale_constrained = True
        for k, v in args_dict.items():
            setattr(self, k, v)

        if self.saturations != [0, 1e8]:
            msg = ('DaleConstrainedIntegrators should be ReLU, not saturated as {}'
                   .format(self.saturations))
            logging.error(msg)
            raise RuntimeError(msg)

        std = 1. / sqrt(self.n)

        # NOTE: per-population fan-out limits for inhibitory / excitatory
        # neurons are not implemented.

        # Fixed (non-trainable) input encoders and output decoders, one per
        # channel. Creation order is kept so seeded draws are unchanged.
        self.encoders = ParameterList([
            Parameter(tch.zeros(self.n).normal_(0, std), requires_grad=False)
            for _ in range(self.n_channels)
        ])
        self.decoders = ParameterList([
            Parameter(tch.zeros(self.n).normal_(0, std), requires_grad=False)
            for _ in range(self.n_channels)
        ])

        if self.init_vectors_type == 'random':
            pass
        elif self.init_vectors_type == 'orthonormal':
            logging.info('Orthogonalizing encoders and decoders')
            # Encoders then decoders are stacked as columns, passed through
            # ``orth`` together and written back in the same column order.
            # NOTE(review): the write-back assigns to ``.data``, which needs
            # ``orth`` to return a torch tensor -- confirm which ``orth`` is
            # imported.
            io_vectors = list(self.encoders) + list(self.decoders)
            plop = tch.zeros(self.n, 2 * self.n_channels)
            for idx, item in enumerate(io_vectors):
                plop[:, idx] = item.data
            plop = orth(plop)
            for idx, item in enumerate(io_vectors):
                item.data = plop[:, idx]

        if self.n_channels == 1:
            encoder, decoder = self.encoders[0], self.decoders[0]
            # Force normalizations to unit L2 norm.
            encoder.data = encoder.data / tch.sqrt((encoder.data**2).sum())
            decoder.data = decoder.data / tch.sqrt((decoder.data**2).sum())
            # Align the decoder towards the encoder by init_vectors_overlap.
            decoder.data = ((1. - self.init_vectors_overlap) * decoder.data
                            + self.init_vectors_overlap * encoder.data)
            # Rescale: decoder norm -> init_vectors_scales[0], encoder
            # multiplied by init_vectors_scales[1].
            decoder.data = self.init_vectors_scales[0] * decoder.data / tch.sqrt(
                (decoder.data**2).sum())
            encoder.data = encoder.data * self.init_vectors_scales[1]

        self.n_inhib = int(self.n * self.inhib_proportion)
        self.n_excit = self.n - self.n_inhib
        self.synapse_signs = Parameter(
            tch.Tensor([1.] * self.n_excit + [-1.] * self.n_inhib),
            requires_grad=False).float()

        self.W = Parameter(tch.zeros(self.n, self.n).normal_(0, std),
                           requires_grad=True)
        # Spectral radius = largest eigenvalue modulus. ``torch.eig`` has been
        # removed from PyTorch; ``linalg.eigvals`` returns complex eigenvalues
        # whose ``abs()`` is that modulus.
        spectral_rad = tch.linalg.eigvals(self.W.data).abs().max().item()
        assert spectral_rad != 0
        self.W.data = self.init_radius * self.W.data / spectral_rad
        if self.init_radius != 0:
            msg = 'DaleConstrainedIntegrators should be initialized with W=0 for now at least'
            logging.error(msg)
            raise RuntimeError(msg)
        assert (self.W.data == 0.).all()

        self.device = tch.device(self.device_name)
        self.to(self.device)
        os.makedirs(self.save_folder, exist_ok=True)

    def step(self, state, inputs, mask, keep_currents=False):
        """Advance the network by one time step.

        Args:
            state: current activity, ``(n,)`` or ``(bs, n)``.
            inputs: list of ``n_channels`` tensors, each reshaped to
                ``(bs, 1)`` and multiplied by its encoder.
            mask: per-neuron multiplicative mask applied to the external
                current and to the new state.
            keep_currents: also return the pre-clamp currents.

        Returns:
            ``(outs, state)`` or ``(outs, state, currents)``, where ``outs``
            is a list of per-channel decoder read-outs of the new state.
        """
        external_current = self.encoders[0] * inputs[0].view(-1, 1)
        for i in range(1, self.n_channels):
            external_current = external_current + self.encoders[i] * inputs[
                i].view(-1, 1)

        # The .t() is here for batch operation, but W is really the coupling
        # matrix with correct convention: W_ij = weight from j to i.
        effective_W = self.W.matmul(tch.diag(self.synapse_signs))
        # Computed once and reused for the recorded currents, instead of
        # evaluating the same matmul twice.
        pre_activation = (state + mask * external_current).matmul(
            effective_W.t())
        state = mask * tch.clamp(pre_activation, *self.saturations)

        outs = [(self.decoders[i] * state).sum(-1)
                for i in range(self.n_channels)]
        if keep_currents:
            return outs, state, pre_activation.detach().clone()
        return outs, state

    def forward(self, inputs, state, mask, keep_currents=False):
        """Run ``step`` over every time step of ``inputs``.

        ``inputs`` is a list of ``n_channels`` tensors indexed as
        ``(batch, time)``. Returns a list of per-channel outputs stacked along
        dim 1 and, if ``keep_currents``, the stacked currents as well.
        """
        # shape[1] is the time length; indexing a batch row (``[0][1]``) to
        # get it failed when the batch size was 1.
        T = inputs[0].shape[1]
        inputs_unbound = [inputs[i].unbind(1) for i in range(self.n_channels)]
        outputs = [
            tch.jit.annotate(List[Tensor], []) for _ in range(self.n_channels)
        ]
        currents = tch.jit.annotate(List[Tensor], [])

        for t in range(T):
            inp = [inputs_unbound[i][t] for i in range(self.n_channels)]
            if keep_currents:
                outs, state, cur = self.step(state, inp, mask,
                                             keep_currents=True)
                currents.append(cur.detach())
            else:
                outs, state = self.step(state, inp, mask,
                                        keep_currents=False)
            for i in range(self.n_channels):
                outputs[i].append(outs[i])

        outputs = [tch.stack(channel, dim=1) for channel in outputs]
        if keep_currents:
            return outputs, tch.stack(currents, dim=1)
        return outputs

    def integrate(self, X, keep_currents=False, mask=None):
        """Validate inputs, convert them to tensors and run the network.

        Args:
            X: list of ``n_channels`` arrays of shape ``(bs, T)``. Entries are
                replaced IN PLACE by device tensors, so calling twice on the
                same list works.
            keep_currents: forwarded to ``forward``.
            mask: optional ``numpy.ndarray`` of per-neuron multipliers
                (defaults to all ones), e.g. for ablations.

        Raises:
            RuntimeError: X is not a list, has the wrong length, or W has a
                negative entry.
        """
        if type(X) is not list:
            msg = 'integrate expects a list as X input, not {}'.format(type(X))
            logging.error(msg)
            raise RuntimeError(msg)
        if len(X) != self.n_channels:
            msg = ('integrate expects same number of input signals as channels, not {} and {}'
                   .format(len(X), self.n_channels))
            logging.error(msg)
            raise RuntimeError(msg)
        if not (self.W >= 0.).all():
            msg = 'Found non fully positive W in integrate, something went wrong in optimization'
            logging.error(msg)
            raise RuntimeError(msg)

        # Make the inputs tensors, or leave them as-is if they already are
        # (from_numpy raises TypeError on a tensor).
        for c in range(self.n_channels):
            try:
                X[c] = tch.from_numpy(X[c]).to(self.device)
            except TypeError:
                pass

        # mask is not used for this project, but could be useful for
        # "ablations" by forcing a subset of neurons to 0 activation.
        if mask is None:
            mask = tch.ones(self.n)
        else:
            assert type(mask) is ndarray
            mask = tch.from_numpy(mask).float()
        mask = mask.to(self.device)

        init_state = tch.zeros(self.n).to(self.device)
        return self.forward(X, init_state, mask, keep_currents=keep_currents)