Example #1
    def forward(self, input):
        """
        Args:
            - input: of shape (seq_len, batch_size).
        Returns:
            - result: of shape (seq_len, batch_size, emb_dim_size)
        """
        # Wrapping in Parameter is important: it turns the result into a leaf node
        embs = Parameter(self.to_embeddings(input).to(self.weight.device))

        if self.requires_emb_grad:
            # registers hook to track gradients of the embedded sequences
            embs.register_hook(self.save_grad)
            # embs.register_hook(print)

        return embs
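
The save_grad callback used above is defined elsewhere in the original class. Below is a minimal, self-contained sketch of the same pattern; the EmbeddingWithGrad module and its save_grad/to_embeddings helpers are hypothetical stand-ins, not part of the original code.

import torch
import torch.nn as nn
from torch.nn import Parameter

class EmbeddingWithGrad(nn.Module):
    """Hypothetical module illustrating the hook pattern from Example #1."""
    def __init__(self, vocab_size, emb_dim, requires_emb_grad=True):
        super().__init__()
        self.weight = Parameter(torch.randn(vocab_size, emb_dim))
        self.requires_emb_grad = requires_emb_grad
        self.saved_grad = None

    def save_grad(self, grad):
        # Called during backward with the gradient w.r.t. the embedded sequence.
        self.saved_grad = grad.detach().clone()

    def to_embeddings(self, input):
        # Simple lookup: (seq_len, batch_size) -> (seq_len, batch_size, emb_dim)
        return self.weight[input]

    def forward(self, input):
        # Wrapping in Parameter detaches the result and makes it a leaf node,
        # so the hook fires with the gradient of the embedded sequence itself.
        embs = Parameter(self.to_embeddings(input).to(self.weight.device))
        if self.requires_emb_grad:
            embs.register_hook(self.save_grad)
        return embs

model = EmbeddingWithGrad(vocab_size=10, emb_dim=4)
tokens = torch.randint(0, 10, (5, 2))    # (seq_len, batch_size)
out = model(tokens)                      # (5, 2, 4)
out.sum().backward()
print(model.saved_grad.shape)            # torch.Size([5, 2, 4])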
Example #2
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.vker1 = Parameter(kernel1, requires_grad=True)
        self.vker1.register_hook(print)
        self.vker2 = Parameter(kernel2, requires_grad=True)
        self.vker2.register_hook(print)
        self.vMat = Parameter(matrix, requires_grad=True)
        self.vMat.register_hook(print)
        self.vVec = Parameter(bias, requires_grad=True)
        self.vVec.register_hook(print)

    def forward(self, x):
        print(x)
        resConv = F.conv2d(x, self.vker1)
        resConv.register_hook(print)
        print(resConv)
        resMax = F.max_pool2d(resConv, 2)
        resMax.register_hook(print)
        print(resMax)
        resRelu = F.relu(resMax)
        resRelu.register_hook(print)
        print(resRelu)
        resConv2 = F.conv2d(resRelu, self.vker2)
        resConv2.register_hook(print)
        print(resConv2)
        resRelu1 = F.relu(resConv2)
        resRelu1.register_hook(print)
        print(resRelu1)
        resMMat = F.linear(resRelu1.view(1, 12), self.vMat, self.vVec)
        resMMat.register_hook(print)
        print(resMMat)
        resLSM = F.log_softmax(resMMat, dim=1)
        resLSM.register_hook(print)
        print(resLSM)
        return resLSM

    def print(self):
        print("==========================================")
        print("model param:")
        print(self.vker1)
        print(self.vker2)
        print(self.vMat)
        print(self.vVec)
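
Example #2 expects kernel1, kernel2, matrix, and bias to exist at module level before Net() is constructed. The sketch below is a hedged usage example; the shapes and the dummy loss are assumptions, chosen only so that the forward pass reaches view(1, 12). During backward(), the registered hooks print the gradients of the parameters and intermediate tensors, roughly in reverse order of the forward pass.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

# Assumed shapes: a 1x1x6x6 input yields a 1x12x1x1 tensor after the second
# convolution, which matches the view(1, 12) call in Net.forward.
kernel1 = torch.randn(2, 1, 3, 3)    # conv1: 1 -> 2 channels, 3x3 kernel
kernel2 = torch.randn(12, 2, 2, 2)   # conv2: 2 -> 12 channels, 2x2 kernel
matrix = torch.randn(4, 12)          # linear: 12 -> 4
bias = torch.randn(4)

net = Net()                                   # uses the class defined above
x = torch.randn(1, 1, 6, 6, requires_grad=True)
out = net(x)                                  # forward pass prints each activation
loss = F.nll_loss(out, torch.tensor([1]))     # dummy target: class index 1
loss.backward()                               # hooks fire: gradients, last op first
net.print()                                   # dumps the (still unchanged) parameters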
Example #3
class down(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(down, self).__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            # Remember: here bias is an offset in the domain, not the codomain
            self.bias = Parameter(torch.Tensor(in_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

        def bH(grad):
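            # _collapse(grad, bias=False) computes grad @ W^T @ W; for (near-)orthonormal
            # rows of W this is the component of each gradient row lying in the row span,
            # so subtracting it keeps only the orthogonal component and helps the rows
            # stay orthonormal under gradient updates.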
            with torch.no_grad():
                return grad - self._collapse(grad, bias=False)

        self.weight.register_hook(bH)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
        # Cheap way to get an orthogonal matrix and keep only part of its rows.
        a = torch.randn(self.out_features, self.in_features)
        _, _, start = svd(a.numpy())
        self.weight = Parameter(torch.Tensor(start[:self.out_features, :]))

    def rescale(self):
        """Resets rows to 1-norm."""
        n = torch.sqrt(torch.sum(self.weight.data * self.weight.data,
                                 1)).view(self.out_features, 1)
        #        print(torch.max(n))
        self.weight.data = self.weight.data / n

    def _forward(self, x, bias=True):
        if self.bias is not None and bias:
            return F.linear(x - self.bias, self.weight, None)
        else:
            return F.linear(x, self.weight, None)

    def _pushback(self, y, bias=True):
        if bias:
            return F.linear(y, self.weight.t(), self.bias)
        else:
            return F.linear(y, self.weight.t(), None)

    def forward(self, x, bias=True):
        # Option to ignore the offset; useful for gradient reset.
        self._fix()
        return self._forward(x, bias)

    def pushback(self, y, bias=True):
        self._fix()
        return self._pushback(y, bias)

    def _collapse(self, x, bias=True):
        """Stays in codomain, but goes down to this linear space."""
        return self._pushback(self._forward(x, bias), bias)

    def collapse(self, x, bias=True):
        self._fix()
        return self._collapse(x, bias)

    def _fix(self):
        while self.badness() > 1e-4:
            with torch.no_grad():
                self.reOrth()

    def reOrth(self):
        # This is a way to push the vectors away from each other:
        # a first-order reorthogonalization step.
        self.weight = Parameter(
            self.weight +
            (self.weight - self._collapse(self.weight, bias=False)) /
            self.out_features)
        self.rescale()

    def badness(self):
        """Measure of non-orthogonality"""
        if self.weight.data.is_cuda:
            y = torch.matmul(self.weight, self.weight.t()) - torch.eye(
                self.out_features).cuda()
        else:
            y = torch.matmul(self.weight, self.weight.t()) - torch.eye(
                self.out_features)
        return torch.sum(y * y)
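
A brief usage sketch of the down layer above. It assumes the imports the class itself relies on (math, torch, nn, F, Parameter, and an unqualified svd, presumably from numpy.linalg) sit at the top of the same module; the feature sizes are arbitrary.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from numpy.linalg import svd   # reset_parameters() calls svd() unqualified

layer = down(in_features=8, out_features=3)
x = torch.randn(5, 8)

y = layer(x)                   # shape (5, 3): coordinates in the layer's 3-dim subspace
x_back = layer.pushback(y)     # shape (5, 8): lifted back into the input space
print(layer.badness())         # ~0, since the rows start out orthonormal from the SVD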
Example #4
class EntNet(nn.Module):
    def __init__(self, vocab_size, memory_slots=20, emb_size=100, bow_encoding=False, max_sentence_length=50, max_query_length=50):
        super(EntNet, self).__init__()
        self.bow_encoding = bow_encoding

        self.memory_slots = memory_slots
        self.vocab_size = vocab_size
        self.emb_size = emb_size

        # Initialize embedding
        emb_init_weight = nn.init.normal_(torch.empty(vocab_size, emb_size), mean=0, std=0.1)
        emb_init_weight[0].zero_()  # 0 val for padding symbol
        self.embedding = nn.Embedding(vocab_size, emb_size, _weight=emb_init_weight)

        # Make the padding symbol's gradient zero
        def emb_hook(grad):
            grad[0].zero_()
            return grad
        self.embedding.weight.register_hook(emb_hook)

        # Initialize Sentence Encoder Parameters with all ones (BoW)
        init_story_weight = torch.ones(max_sentence_length, emb_size)
        init_query_weight = torch.ones(max_query_length, emb_size)
        if self.bow_encoding:
            self.mult_mask = nn.Embedding.from_pretrained(init_story_weight, freeze=True)
            self.mult_mask_query = nn.Embedding.from_pretrained(init_query_weight, freeze=True)
        else:
            self.mult_mask = nn.Embedding(max_sentence_length, emb_size, _weight=init_story_weight)
            self.mult_mask_query = nn.Embedding(max_query_length, emb_size, _weight=init_query_weight)

        # Initialize the PRelu activation
        self.activation = nn.PReLU(num_parameters=emb_size, init=1.0)

        self.memory_net = DynamicMemory(hidden_size=emb_size, memory_slots=memory_slots, activation=self.activation)

        # Output module parameters
        # Initialize R
        R_init_weight = nn.init.normal_(torch.empty(emb_size, vocab_size), mean=0, std=0.1)
        # Zero out the 0th column, which is the output embedding for the padding symbol
        R_init_weight[:, 0].zero_()
        self.R = Parameter(R_init_weight)

        # Make gradient corresponding to padding symbol 0
        def R_hook(grad):
            grad[:, 0].zero_()
            return grad
        self.R.register_hook(R_hook)

        self.H = Parameter(nn.init.normal_(torch.empty(emb_size, emb_size), mean=0, std=0.1))


    def encode_sent(self, batch_sent, query=False):
        """
        Encode a batch of sentences.
        batch_sent: B x L
        """
        # Embed Sentence
        sent_tensor = self.embedding(batch_sent)  # B x L x E
        # Get non-zero indices (Will only select indices >= 1)
        mask_tensor = torch.unsqueeze(torch.ge(batch_sent, 1), dim=2).cuda()  # B x L x 1
        # Cast the mask_tensor
        mask_tensor = mask_tensor.type(dtype=torch.cuda.FloatTensor)
        # Mask out padded symbols
        sent_tensor = sent_tensor * mask_tensor # B x L x E

        # Get embedding for multiplicative masks
        _, sentence_length = list(batch_sent.size())
        idx_tensor = torch.arange(sentence_length).cuda()
        if query:
            mult_mask_tensor = self.mult_mask_query(idx_tensor)
        else:
            mult_mask_tensor = self.mult_mask(idx_tensor) # L x E

        batch_sent_emb = sent_tensor * mult_mask_tensor # B x L x E
        batch_sent_emb = torch.sum(batch_sent_emb, 1) # B x E

        return batch_sent_emb


    def encode_story(self, batch_story):
        """Encode a batch of stories.
        batch_story: B x T x L (T is # of sentences and L is max length of each sentence)

        Returns:
        enc_stories: list of B x H tensors with length T.
        mask_stories: list of length T with tensors of size (B,)
        """
        batch_story = torch.transpose(batch_story, 0, 1)  # T x B x L

        # Split stories along sentences
        seq_batch_sent = torch.unbind(batch_story, dim=0)  # Tuple of B x L tensors
        enc_stories = []
        mask_stories = []
        for batch_sent in seq_batch_sent:
            enc_stories.append(self.encode_sent(batch_sent))
            # Check if all the symbols are pad or not
            mask_at_t = torch.ge(torch.sum(batch_sent, dim=1), 1.0).type(dtype=torch.cuda.FloatTensor)
            mask_stories.append(mask_at_t)
        return enc_stories, mask_stories


    def read_story_and_answer_question(self, batch_story, batch_question):
        """Read stories via memory network and answer related questions.
        batch_story: B x T x L (T is # of sentences and L is sentence length)
        batch_question: B x L

        Returns: prediction scores of shape B x V.
        """
        enc_stories, mask_stories = self.encode_story(batch_story)
        enc_query = self.encode_sent(batch_question, query=True)  # B x H

        memorized_stories = self.memory_net.read_story(enc_stories, mask_stories)  # M x B x H

        # Get softmax scores for memories
        memory_scores = torch.sum(memorized_stories * enc_query, dim=2)  # M x B
        softmax_scores = nn.functional.softmax(memory_scores, dim=0) # M x B

        # Get story representation corresponding to query
        softmax_scores = torch.unsqueeze(softmax_scores, 2)  # M x B x 1
        weighted_memories = softmax_scores * memorized_stories  # M x B x H
        story_repr = torch.sum(weighted_memories, dim=0)  # B x H

        # Output scores
        activation_input = enc_query + torch.mm(story_repr, self.H)  # B x H
        pred = torch.mm(self.activation(activation_input), self.R)  # B x V

        return pred

    def get_loss(self, preds, target):
        """Return the loss function given the prediction and correct labels.
        preds: B x V
        target: B (integer valued)
        """
        loss = nn.functional.cross_entropy(input=preds, target=target)
        return loss
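
The central register_hook trick in Example #4 is keeping the padding embedding frozen by zeroing its gradient row. Below is a minimal, self-contained sketch of just that pattern (the EntNet/DynamicMemory machinery is omitted); unlike the original hook, the sketch clones the gradient before zeroing it, since hooks are documented to return a new gradient rather than modify their argument in place.

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=6, embedding_dim=3)
emb.weight.data[0].zero_()          # row 0 is the padding symbol

def emb_hook(grad):
    # Zero the gradient of the padding row so it never receives updates.
    grad = grad.clone()
    grad[0].zero_()
    return grad

emb.weight.register_hook(emb_hook)

tokens = torch.tensor([[0, 2, 3], [0, 1, 5]])   # first column is padding
emb(tokens).sum().backward()
print(emb.weight.grad[0])   # tensor([0., 0., 0.])
print(emb.weight.grad[2])   # non-zero: a real token's gradient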