예제 #1
0
    def get_data(self, datasets):
        """Collect transformed inputs and labels from every batch.

        Iterates each dataset once (safe for one-shot iterators), running
        ``self.transform`` on the batch text and gathering the raw label
        values, then concatenates everything into two tensors.

        Returns:
            tuple: ``(X, Y)`` — concatenated features and labels.
        """
        features, labels = [], []
        for dataset in datasets:
            for batch in dataset:
                features.append(self.transform(batch.text))
                labels.append(batch.label.values)
        return torch.cat(features), torch.cat(labels)
예제 #2
0
    def forward(self, batch):
        batch = batch.transpose(0, 1)

        embeddings = self.dropout(self.embedding(batch))
        _, (hidden, _) = self.rnn(embeddings)

        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        return self.out(self.dropout(hidden).squeeze(0))
예제 #3
0
    def transform(self, batch):
        """Build a feature matrix for *batch* from the pretrained models,
        optionally augmented with bag-of-words and CBOW features.

        Uses the module-level ``TEXT`` vocabulary and ``EMBEDDINGS`` table
        when the corresponding flags are enabled.
        """
        per_model = [model(batch) for model in self.pretrained_models]
        features = torch.stack(per_model, dim=1)
        # Sentence-major view of the raw token ids.
        sentences = batch.values.transpose(0, 1)

        if self.include_bag_of_words:
            # One vocabulary-sized token-count vector per sentence.
            counts = torch.stack([
                torch.bincount(sent, minlength=len(TEXT.vocab))
                for sent in sentences
            ]).float()
            features = torch.cat((features, counts), dim=1)

        if self.include_cbow:
            # Continuous bag-of-words: summed token embeddings per sentence.
            summed = torch.stack(
                [EMBEDDINGS[sent].sum(dim=0) for sent in sentences])
            features = torch.cat((features, summed), dim=1)

        return features
    def get_data(self, datasets):
        """Flatten *datasets* into per-sentence feature rows and labels.

        Each batch's text is reordered to sentence-major layout, every
        sentence is run through ``self.transform`` individually, and labels
        are collected batch-wise as floats.

        Returns:
            tuple: ``(X, Y)`` — stacked per-sentence features and
            concatenated labels.
        """
        features = []
        labels = []
        for dataset in datasets:
            for batch in dataset:
                tokens = batch.text.transpose('batch',
                                              'seqlen').values.clone()
                labels.append(batch.label.values.float())
                features.extend(self.transform(sent) for sent in tokens)
        return (torch.stack(features), torch.cat(labels))
예제 #5
0
    def formatBatch(self, batch):
        """Pad every sentence and target sequence in *batch* to a fixed width.

        Each sentence is left-padded with ``self.n - 2`` zeros (context
        window for an n-gram model — TODO confirm) and right-padded out to
        ``BATCHSIZE`` entries; targets are right-padded to ``BATCHSIZE``.

        Fix: removed the ``num_words`` accumulator — it was computed on
        every sentence but never used or returned (dead code).

        Returns:
            tuple: ``(allsamples, allresults)`` — a stacked tensor of
            padded sentences and a flat concatenation of padded targets.
        """
        samples_list = []
        results_list = []
        for i in range(len(batch)):
            sent = batch.text[{"batch": i}].values.data
            targets = batch.target.get("batch", i).values.data

            # Left pad supplies the leading context; right pad normalizes
            # every row to BATCHSIZE positions (zeros assumed to be the
            # padding token id — verify against the vocabulary).
            padded_sent = torch.nn.functional.pad(
                sent, (self.n - 2, BATCHSIZE - len(sent)), value=0)
            nextwords = torch.nn.functional.pad(targets,
                                                (0, BATCHSIZE - len(targets)),
                                                value=0)

            samples_list.append(padded_sent)
            results_list.append(nextwords)

        allsamples = torch.stack(samples_list)
        allresults = torch.cat(results_list)

        return (allsamples, allresults)