Example #1
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class Linearlr(nn.Module):
    def __init__(self, in_features, out_features, rank, bias=True):
        super(Linearlr, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        print("rank {}, in_features {}, out_features {}".format(
            rank, in_features, out_features))
        assert rank <= min(in_features, out_features)
        self.rank = rank
        self.weightA = Parameter(torch.Tensor(rank, in_features))
        self.weightB = Parameter(torch.Tensor(out_features, rank))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weightA.size(1))
        self.weightA.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

        stdv = 1. / math.sqrt(self.weightB.size(1))
        self.weightB.data.uniform_(-stdv, stdv)

    def forward(self, input):
        weight = self.weightB.matmul(self.weightA)
        return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)
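
A minimal usage sketch for the low-rank layer above; the sizes are illustrative assumptions, not taken from the original source:

layer = Linearlr(in_features=16, out_features=8, rank=4)
x = torch.randn(32, 16)   # batch of 32 inputs
y = layer(x)              # effective weight is weightB @ weightA, shape (8, 16)
print(y.shape)            # torch.Size([32, 8])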
Example #2
import torch
import torch.nn as nn
from torch.nn import Parameter

# masked_softmax and weighted_sum are helper functions assumed to come from the
# surrounding project (ESIM-style attention utilities); one possible implementation
# is sketched below the class.
class WeightedAttention(nn.Module):
    """
    Attention layer taking premises and hypotheses encoded by an RNN as input
    and computing the soft attention between their elements.

    The dot product of the encoded vectors in the premises and hypotheses is
    first computed. The softmax of the result is then used in a weighted sum
    of the vectors of the premises for each element of the hypotheses, and
    conversely for the elements of the premises.
    """
    def __init__(self, embedding_dim):
        super(WeightedAttention, self).__init__()
        self.w = Parameter(torch.Tensor(embedding_dim, embedding_dim))
        torch.nn.init.xavier_normal_(self.w)

    def forward(self, premise_batch, premise_mask, hypothesis_batch,
                hypothesis_mask):
        """
        Args:
            premise_batch: A batch of sequences of vectors representing the
                premises in some NLI task. The batch is assumed to have the
                size (batch, sequences, vector_dim).
            premise_mask: A mask for the sequences in the premise batch, to
                ignore padding data in the sequences during the computation of
                the attention.
            hypothesis_batch: A batch of sequences of vectors representing the
                hypotheses in some NLI task. The batch is assumed to have the
                size (batch, sequences, vector_dim).
            hypothesis_mask: A mask for the sequences in the hypotheses batch,
                to ignore padding data in the sequences during the computation
                of the attention.

        Returns:
            attended_premises: The sequences of attention vectors for the
                premises in the input batch.
            attended_hypotheses: The sequences of attention vectors for the
                hypotheses in the input batch.
            similarity_matrix: The matrix of premise/hypothesis similarity
                scores used to compute the attention weights.
        """
        # Dot product between premises and hypotheses in each sequence of
        # the batch.
        similarity_matrix = premise_batch.matmul(
            self.w.matmul(hypothesis_batch.transpose(2, 1).contiguous()))

        # Softmax attention weights.
        prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
        hyp_prem_attn = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)

        # Weighted sums of the hypotheses for the premises attention,
        # and vice-versa for the attention of the hypotheses.
        attended_premises = weighted_sum(hypothesis_batch, prem_hyp_attn,
                                         premise_mask)
        attended_hypotheses = weighted_sum(premise_batch, hyp_prem_attn,
                                           hypothesis_mask)

        return attended_premises, attended_hypotheses, similarity_matrix
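
The class above relies on masked_softmax and weighted_sum helpers that are not shown here. A minimal sketch of compatible implementations and a forward call, assuming the helpers live in the same module as the class and that masks have shape (batch, seq_len) with 1 for real tokens and 0 for padding:

def masked_softmax(tensor, mask):
    # tensor: (batch, len_a, len_b); mask: (batch, len_b).
    tensor = tensor.masked_fill(mask.unsqueeze(1) == 0, -1e9)
    return torch.softmax(tensor, dim=-1)


def weighted_sum(tensor, weights, mask):
    # tensor: (batch, len_b, dim); weights: (batch, len_a, len_b); mask: (batch, len_a).
    return weights.bmm(tensor) * mask.unsqueeze(-1)


attention = WeightedAttention(embedding_dim=300)
premises, hypotheses = torch.randn(2, 7, 300), torch.randn(2, 5, 300)
premise_mask, hypothesis_mask = torch.ones(2, 7), torch.ones(2, 5)
att_prem, att_hyp, sim = attention(premises, premise_mask,
                                   hypotheses, hypothesis_mask)
print(att_prem.shape, att_hyp.shape, sim.shape)  # (2, 7, 300) (2, 5, 300) (2, 7, 5)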
Example #3
import torch
import torch.nn as nn
from torch.nn import Parameter

# l_miast ("number of cities") and l_miesiecy ("number of months") are module-level
# globals that must be defined before this class is instantiated.
class Factorize(nn.Module):
    def __init__(self, factors):
        super(Factorize, self).__init__()
        self.A = Parameter(torch.randn(l_miast, factors))
        self.B = Parameter(torch.randn(factors, l_miesiecy))
        self.global_bias = Parameter(torch.randn(1))
        self.bias_miast = Parameter(torch.randn(l_miast))

    def forward(self):
        # Reconstruct the (l_miast, l_miesiecy) matrix from the factors, then add
        # the global bias and a per-city bias broadcast across the months.
        output = self.A.matmul(self.B) + self.global_bias
        output = output + self.bias_miast.unsqueeze(1)
        return output
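
A minimal sketch of fitting the factorization, assuming illustrative values for the l_miast / l_miesiecy globals (defined in the same module as the class) and a hypothetical random target matrix:

l_miast, l_miesiecy = 10, 12               # illustrative sizes: cities x months
target = torch.randn(l_miast, l_miesiecy)

model = Factorize(factors=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for _ in range(100):
    optimizer.zero_grad()
    loss = ((model() - target) ** 2).mean()
    loss.backward()
    optimizer.step()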
Example #4
import torch
from torch.nn import Parameter

# MLPAbstractPredictor, self.merge_dim, self.before_merge_mlp, self.after_merge_mlp
# and self.transform_single_drug come from the surrounding project.
class BilinearMLPAbstractPredictor(MLPAbstractPredictor):
    """
    Similar to the MLP Abstract Predictor but applies a bilinear transform instead of addition.
    """
    def __init__(self, data, config, predictor_layers, uses_raw_response):
        super(BilinearMLPAbstractPredictor,
              self).__init__(data, config, predictor_layers, uses_raw_response)

        # Number of bilinear transformations == the dimension of the layer at which the merge is performed
        # Initialize weights close to identity
        self.bilinear_weights = Parameter(
            1 / 100 *
            torch.randn((self.merge_dim, self.merge_dim, self.merge_dim)) +
            torch.cat([torch.eye(self.merge_dim)[None, :, :]] * self.merge_dim,
                      dim=0))
        self.bilinear_offsets = Parameter(1 / 100 * torch.randn(
            (self.merge_dim)))

    def single_forward_pass(self, h_drug_1, h_drug_2, cell_lines):

        # Apply before merge MLP
        h_1 = self.before_merge_mlp([h_drug_1, cell_lines])[0]
        h_2 = self.before_merge_mlp([h_drug_2, cell_lines])[0]

        # compute <W.h_1, W.h_2> = h_1.T . W.T.W . h_2
        # (permute(2, 1, 0) reverses the dims of the 3-d result, matching what the
        # deprecated Tensor.T did on non-2-d tensors)
        h_1 = self.bilinear_weights.matmul(h_1.T).permute(2, 1, 0)
        h_2 = self.bilinear_weights.matmul(h_2.T).permute(2, 1, 0)

        # "Transpose" h_1
        h_1 = h_1.permute(0, 2, 1)

        # Multiplication
        h_1_scal_h_2 = (h_1 * h_2).sum(1)

        # Add offset
        h_1_scal_h_2 += self.bilinear_offsets

        comb = self.after_merge_mlp([h_1_scal_h_2, cell_lines])[0]

        return (
            comb,
            self.transform_single_drug(h_drug_1, cell_lines),
            self.transform_single_drug(h_drug_2, cell_lines),
        )
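
Design note: per the docstring and comments above, the merge replaces the parent class's additive combination with a learned bilinear form, and the bilinear weights are initialized as a stack of identity matrices plus small (1/100-scaled) Gaussian noise, presumably so that training starts from a near-neutral transform and learns interaction-specific structure from there.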
Example #5
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class Linearsp_v2(nn.Module):
    def __init__(self, in_features, out_features, rank, bias=True):
        super(Linearsp_v2, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # print("rank {}, in_features {}, out_features {}".format(rank, in_features, out_features))
        assert rank <= min(in_features, out_features)
        self.rank = rank
        self.weightA = Parameter(torch.zeros(rank, in_features))
        self.weightB = Parameter(torch.zeros(out_features, rank))
        self.weightC = Parameter(torch.zeros(out_features, in_features))

        # Constant identity used for the residual parametrization in forward().
        self.register_buffer('eye_const', torch.eye(rank))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weightA.size(1))
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        # The identity offset requires rank == in_features (first branch) or
        # rank == out_features (second branch) for the addition to be well-shaped.
        if self.rank == self.in_features:
            weight = self.weightB.matmul(self.weightA +
                                         self.eye_const) + self.weightC
        else:
            weight = (self.weightB + self.eye_const).matmul(
                self.weightA) + self.weightC
        return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)
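
A minimal sketch exercising the rank == in_features branch (sizes are illustrative assumptions):

layer = Linearsp_v2(in_features=8, out_features=16, rank=8)
x = torch.randn(4, 8)
y = layer(x)     # weightA/B/C start at zero, so initially the output equals the bias
print(y.shape)   # torch.Size([4, 16])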
Example #6
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class BaseRNNCell(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 bias=False,
                 nonlinearity="tanh",
                 hidden_min_abs=0,
                 hidden_max_abs=None,
                 hidden_init=None,
                 recurrent_init=None,
                 gradient_clip=5):
        super(BaseRNNCell, self).__init__()
        self.hidden_max_abs = hidden_max_abs
        self.hidden_min_abs = hidden_min_abs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.nonlinearity = nonlinearity
        self.hidden_init = hidden_init
        self.recurrent_init = recurrent_init
        if self.nonlinearity == "tanh":
            self.activation = F.tanh
        elif self.nonlinearity == "relu":
            self.activation = F.relu
        elif self.nonlinearity == "sigmoid":
            self.activation = F.sigmoid
        elif self.nonlinearity == "log":
            self.activation = torch.log
        elif self.nonlinearity == "sin":
            self.activation = torch.sin
        else:
            raise RuntimeError("Unknown nonlinearity: {}".format(
                self.nonlinearity))

        self.weight_ih = Parameter(torch.eye(hidden_size, input_size))
        # Note: weight_hh is only used by check_bounds() and the commented-out code;
        # the active forward() uses weight_ih @ weight_hh1.
        self.weight_hh = Parameter(torch.Tensor(hidden_size, 20).uniform_())
        self.weight_hh1 = Parameter(torch.eye(input_size, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.randn(hidden_size))
        else:
            self.register_parameter('bias_ih', None)
        # self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    # def reset_parameters(self):
    #     for name, weight in self.named_parameters():
    #         if "bias" in name:
    #             weight.data.zero_()
    #         elif "weight_hh" in name:
    #             if self.recurrent_init is None:
    #                 nn.init.constant_(weight, 1)
    #             else:
    #                 self.recurrent_init(weight)
    #         elif "weight_ih" in name:
    #             if self.hidden_init is None:
    #                 nn.init.normal_(weight, 0, 0.01)
    #             else:
    #                 self.hidden_init(weight)
    #         else:
    #             weight.data.normal_(0, 0.01)
    #             # weight.data.uniform_(-stdv, stdv)
    #     self.check_bounds()

    def check_bounds(self):
        if self.hidden_min_abs:
            abs_kernel = torch.abs(
                self.weight_hh.data).clamp_(min=self.hidden_min_abs)
            self.weight_hh.data = torch.mul(torch.sign(self.weight_hh.data),
                                            abs_kernel)
        if self.hidden_max_abs:
            self.weight_hh.data = self.weight_hh.clamp(
                max=self.hidden_max_abs, min=-self.hidden_max_abs)

    def forward(self, input, hx):
        # x = F.linear(input, self.weight_ih, self.bias_ih) + torch.matmul(hx, self.weight_hh.matmul(self.weight_hh1))
        # return self.talor(x)
        return self.activation(
            F.linear(input, self.weight_ih, self.bias_ih) +
            torch.matmul(hx, self.weight_ih.matmul(self.weight_hh1)))

    def talor(self, x):
        # Third-order Taylor expansion of log(x) around x = 1.
        return (x -
                1) - (x - 1) * (x - 1) / 2 + (x - 1) * (x - 1) * (x - 1) / 3
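
A minimal sketch stepping the cell over a short sequence (sizes and the choice of nonlinearity are illustrative assumptions):

cell = BaseRNNCell(input_size=10, hidden_size=16, nonlinearity="relu")
x = torch.randn(8, 5, 10)   # (batch, time, input_size)
hx = torch.zeros(8, 16)     # initial hidden state
for t in range(x.size(1)):
    hx = cell(x[:, t, :], hx)
print(hx.shape)             # torch.Size([8, 16])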
Example #7
import logging
import os
from math import sqrt
from typing import List

import torch as tch
from numpy import ndarray
from scipy.linalg import orth
from torch import Tensor
from torch.nn import Module, Parameter, ParameterList


class DaleConstrainedIntegrator(Module):
    def __init__(self, args_dict):
        super(DaleConstrainedIntegrator, self).__init__()
        self.is_W_parametrized = True
        self.is_dale_constrained = True

        for k, v in args_dict.items():
            setattr(self, k, v)
        if self.saturations != [0, 1e8]:
            logging.error(
                'DaleConstrainedIntegrators should be ReLU, not saturated as {}'
                .format(self.saturations))
            raise RuntimeError

        std = 1. / sqrt(self.n)

        # Dale specific parameters
        # self.inhib_proportion = .25 # Fraction of neurons that will be inhibitory, should now be a parameter
        # Don't add that yet...
        # self.inhib_fan_out = 20 # Number of allowed out-going connections for inhibitory neurons
        # self.excit_fan_out = 20 # Number of allowed out-going connections for excitatory neurons

        self.encoders = ParameterList([
            Parameter(tch.zeros(self.n).normal_(0, std), requires_grad=False)
            for _ in range(self.n_channels)
        ])
        self.decoders = ParameterList([
            Parameter(tch.zeros(self.n).normal_(0, std), requires_grad=False)
            for _ in range(self.n_channels)
        ])
        if self.init_vectors_type == 'random':
            pass
        elif self.init_vectors_type == 'orthonormal':
            logging.info('Orthogonalizing encoders and decoders')
            plop = tch.zeros(self.n, 2 * self.n_channels)
            for idx, item in enumerate(self.encoders):
                plop[:, idx] = item.data
            for idx, item in enumerate(self.decoders):
                plop[:, len(self.encoders) + idx] = item.data
            # scipy's orth works on NumPy arrays; convert back to a tensor afterwards.
            plop = tch.from_numpy(orth(plop.numpy()))
            for idx, item in enumerate(self.encoders):
                item.data = plop[:, idx]
            for idx, item in enumerate(self.decoders):
                item.data = plop[:, len(self.encoders) + idx]

        if self.n_channels == 1:
            # Force normalizations
            self.encoders[0].data = self.encoders[0].data / tch.sqrt(
                (self.encoders[0].data**2).sum())
            self.decoders[0].data = self.decoders[0].data / tch.sqrt(
                (self.decoders[0].data**2).sum())
            # Align the encoder / decoder
            self.decoders[0].data = (
                (1. - self.init_vectors_overlap) * self.decoders[0].data +
                self.init_vectors_overlap * self.encoders[0].data)
            # Rescale the io vectors
            self.decoders[0].data = self.init_vectors_scales[
                0] * self.decoders[0].data / tch.sqrt(
                    (self.decoders[0].data**2).sum())
            self.encoders[
                0].data = self.encoders[0].data * self.init_vectors_scales[1]

        self.n_inhib = int(self.n * self.inhib_proportion)
        self.n_excit = self.n - self.n_inhib
        self.synapse_signs = Parameter(
            tch.Tensor([1. for _ in range(self.n_excit)] +
                       [-1. for _ in range(self.n_inhib)]),
            requires_grad=False).float()
        self.W = Parameter(tch.zeros(self.n, self.n).normal_(0, std),
                           requires_grad=True)
        # torch.eig was removed from recent PyTorch; eigvals gives the same spectral radius.
        eigs = tch.linalg.eigvals(self.W)
        spectral_rad = eigs.abs().max().item()
        assert spectral_rad != 0
        self.W.data = self.init_radius * self.W.data / spectral_rad
        if self.init_radius != 0:
            logging.error(
                'DaleConstrainedIntegrators should be initialized with W=0 for now at least'
            )
            raise RuntimeError
        assert (self.W.data == 0.).all()
        self.device = tch.device(self.device_name)
        self.to(self.device)
        os.makedirs(self.save_folder, exist_ok=True)

    def step(self, state, inputs, mask, keep_currents=False):
        external_current = self.encoders[0] * inputs[0].view(-1, 1)
        for i in range(1, self.n_channels):
            external_current = external_current + self.encoders[i] * inputs[
                i].view(-1, 1)
        if keep_currents:
            cur = (state + mask * external_current).matmul((self.W.matmul(
                tch.diag(self.synapse_signs))).t()).detach().clone()
        state = mask * tch.clamp(
            (state + mask * external_current).matmul((self.W.matmul(
                tch.diag(self.synapse_signs))).t()), *self.saturations)
        # The .t() calls above are there for batched operation; W itself is the coupling
        # matrix with the usual convention W_ij = weight from neuron j to neuron i.
        outs = [(self.decoders[i] * state).sum(-1)
                for i in range(self.n_channels)]
        if keep_currents:
            return outs, state, cur
        else:
            return outs, state

    def forward(self, inputs, state, mask, keep_currents=False):
        T = inputs[0].shape[1]  # number of time steps
        inputs_unbinded = [inputs[i].unbind(1) for i in range(self.n_channels)]
        outputs = [
            tch.jit.annotate(List[Tensor], []) for _ in range(self.n_channels)
        ]
        if keep_currents:
            currents = tch.jit.annotate(List[Tensor], [])
        for t in range(T):
            inp = [inputs_unbinded[i][t] for i in range(self.n_channels)]
            if keep_currents:
                outs, state, cur = self.step(state,
                                             inp,
                                             mask,
                                             keep_currents=True)
                currents += [cur.detach()]
            else:
                outs, state = self.step(state, inp, mask, keep_currents=False)
            for i in range(self.n_channels):
                outputs[i] = outputs[i] + [outs[i]]
        for i in range(self.n_channels):
            outputs[i] = tch.stack(outputs[i], dim=1)
        if keep_currents:
            return outputs, tch.stack(currents, dim=1)
        else:
            return outputs

    def integrate(self, X, keep_currents=False, mask=None):
        # Expect X to be [np.array(bs, T) for c in range(n_channels)]
        if type(X) is not list:
            logging.error('integrate expects a list as X input, not {}'.format(
                type(X)))
            raise RuntimeError
        if len(X) != self.n_channels:
            logging.error(
                'integrate expects same number of input signals as channels, not {} and {}'
                .format(len(X), self.n_channels))
            raise RuntimeError
        if not (self.W >= 0.).all():
            logging.error(
                'Found negative entries in W in integrate, something went wrong in optimization'
            )
            raise RuntimeError

        # Make the inputs tch tensors, or do nothing if they already are (e.g. when calling integrate twice on the same X)
        for c in range(self.n_channels):
            try:
                X[c] = tch.from_numpy(X[c]).to(self.device)
            except TypeError:
                pass

        # mask is not used for this project, but could be useful for implementing "ablations"
        # by forcing a subset of neurons to have 0 activation at all times
        tmp = tch.ones(self.n)
        if mask is not None:
            assert type(mask) is ndarray
            tmp = tch.from_numpy(mask).float()
        mask = tmp.to(self.device)

        init_state = tch.zeros(self.n).to(self.device)
        return self.forward(X, init_state, mask, keep_currents=keep_currents)
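
The constructor reads its whole configuration from args_dict via setattr. A minimal sketch of a dictionary covering the keys the code above actually reads (all values are illustrative assumptions), followed by one call to integrate on random input:

import numpy as np

args_dict = {
    'n': 100,                       # number of neurons
    'n_channels': 1,                # number of input/output channels
    'saturations': [0, 1e8],        # must be ReLU-like, as checked in __init__
    'init_vectors_type': 'random',  # or 'orthonormal'
    'init_vectors_overlap': 0.,     # encoder/decoder alignment (n_channels == 1 only)
    'init_vectors_scales': [1., 1.],
    'inhib_proportion': 0.25,       # fraction of inhibitory neurons
    'init_radius': 0.,              # must be 0: W is forced to start at zero
    'device_name': 'cpu',
    'save_folder': './tmp_integrator',
}
model = DaleConstrainedIntegrator(args_dict)

X = [np.random.randn(4, 50).astype('float32')]  # one channel, batch of 4, 50 time steps
outputs = model.integrate(X)
print(outputs[0].shape)                         # torch.Size([4, 50])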