Example #1
    def __add__(self, output_batch):
        # merge the running statistics; a key missing from one side
        # contributes zero instead of raising a KeyError
        new_batch_stats = {}
        for k in set(self.batch_stats.keys()).union(
                output_batch.batch_stats.keys()):
            new_batch_stats[k] = self.batch_stats.get(k, 0) \
                + output_batch.batch_stats.get(k, 0)
        new_batch_length = self.batch_length + output_batch.batch_length
        if self.outputs is not None:
            new_outputs = {}
            for k in set(self.outputs.keys()).union(
                    output_batch.outputs.keys()):
                # pad to a common shape and stack, then flatten the new
                # stacking dim back into the batch dim; keys present on
                # only one side are kept as-is
                new_outputs[k] = pad_and_concat(
                    [b.outputs[k] for b in (self, output_batch)
                     if k in b.outputs])
                new_outputs[k] = new_outputs[k].view(
                    -1, *new_outputs[k].shape[2:])
        else:
            new_outputs = None
        if self.batch is not None:
            raise NotImplementedError
        else:
            new_batch = None
        return self.__class__(new_batch_length,
                              new_batch_stats,
                              batch=new_batch,
                              outputs=new_outputs)
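All of these snippets lean on pad_and_concat, which is not shown. Below is a minimal sketch consistent with how it is called here, not the real implementation: it zero-pads tensors to a common shape and stacks them along a new leading dimension (the auto flag seen in Example #3 is accepted but ignored, since its meaning isn't visible in these snippets).

import torch
import torch.nn.functional as F

def pad_and_concat(tensors, auto=True):
    # Hypothetical sketch. Assumes every tensor has the same number of
    # dimensions; `auto` is ignored because its real meaning is unknown.
    dims = tensors[0].dim()
    max_shape = [max(t.size(d) for t in tensors) for d in range(dims)]
    padded = []
    for t in tensors:
        pad = []
        # F.pad takes (left, right) pad widths starting from the last dim
        for d in reversed(range(dims)):
            pad += [0, max_shape[d] - t.size(d)]
        padded.append(F.pad(t, pad))
    # stack along a new leading dim: two [B, ...] tensors become
    # [2, B, ...], which Example #1 then flattens to [2 * B, ...] via .view
    return torch.stack(padded, dim=0)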
Example #2
def collate_datapoints(datapoints):
    # Assumes every datapoint has the same keys as the first one:
    # tensor fields are padded and stacked into a batch; everything
    # else is collected into a plain list.
    return {
        k: (pad_and_concat([datapoint[k] for datapoint in datapoints])
            if isinstance(datapoints[0][k], torch.Tensor) else
            [datapoint[k] for datapoint in datapoints])
        for k in datapoints[0].keys()
    }
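A hypothetical call, to make the two branches concrete (the field names are invented): tensor fields come back padded and stacked, everything else as a plain list.

import torch

datapoints = [
    {'tokens': torch.tensor([1, 2, 3]), 'doc_id': 'a'},
    {'tokens': torch.tensor([4, 5]), 'doc_id': 'b'},
]
batch = collate_datapoints(datapoints)
# batch['tokens'] has shape [2, 3] (the second row zero-padded);
# batch['doc_id'] is simply ['a', 'b']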
Example #3
    def stats(cls, batch, outputs):
        returned_stats = statistics_func(**outputs)
        # pull the three per-class count vectors out of the stat dict
        # and stack them into a single tensor
        counts = pad_and_concat([returned_stats.pop('true_positives'),
                                 returned_stats.pop('positives'),
                                 returned_stats.pop('relevants')],
                                auto=False)
        # everything left is a scalar tensor; unwrap to plain numbers
        returned_stats = {k: v.item() for k, v in returned_stats.items()}
        returned_stats['counts'] = counts
        return returned_stats
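stats assumes statistics_func returns a dict of scalar tensors plus the three per-class count vectors it pops out. A hypothetical return value satisfying that contract:

import torch

def statistics_func(**outputs):
    # Hypothetical stand-in; the real function computes these from the
    # model outputs it receives as keyword arguments.
    return {
        'loss': torch.tensor(0.73),                # scalar, unwrapped by .item()
        'accuracy': torch.tensor(0.91),            # scalar, unwrapped by .item()
        'true_positives': torch.tensor([3., 1.]),  # per-class counts, stacked
        'positives': torch.tensor([5., 2.]),       # into the 3-row `counts`
        'relevants': torch.tensor([4., 3.]),       # tensor by pad_and_concat
    }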
Example #4
    def __init__(self,
                 raw_datapoint,
                 tokenizer,
                 codes,
                 hierarchy,
                 ancestors=False,
                 code_id=False,
                 code_description=False,
                 code_linearization=False,
                 description_linearization=False,
                 description_embedding_linearization=False,
                 resample_neg_proportion=None,
                 counts=None,
                 tfidf_tokenizer=False,
                 filter=lambda x: True):
        self.raw_datapoint = raw_datapoint
        self.datapoint = {}
        self.observed = []

        # process reports
        # -1000 appears to act as a "no sentence limit" sentinel
        num_sentences = raw_datapoint.get('num_sentences', -1000)
        sentences, self.sentence_spans = get_sentences(
            raw_datapoint['reports'],
            num_sentences=num_sentences,
            filter=filter)
        self.tokenized_sentences = [
            tokenizer.tokenize(sent) for sent in sentences
        ]
        if not tfidf_tokenizer:
            self.datapoint['article_sentences'] = pad_and_concat([
                torch.tensor(tokenizer.convert_tokens_to_ids(sent))
                for sent in self.tokenized_sentences
            ])
            self.datapoint['article_sentences_lengths'] = torch.tensor(
                [len(sent) for sent in self.tokenized_sentences])
            self.observed += ['article_sentences', 'article_sentences_lengths']

        # get code_id
        # this happens regardless of whether it is observed because it
        # might be needed for supervision
        if 'targets' in raw_datapoint.keys():
            # codes are only resolved when targets are given
            targets = raw_datapoint['targets']
            if 'labels' in raw_datapoint.keys():
                labels = raw_datapoint['labels']
                if ancestors or resample_neg_proportion is not None:
                    positives, negatives = get_pos_neg(targets, labels)
                if ancestors:
                    positives = hierarchy.ancestors(positives)
                    negatives = hierarchy.ancestors(negatives,
                                                    stop_nodes=positives)
                if resample_neg_proportion is not None:
                    # resample negatives in proportion to each code's
                    # positive prior, downweighted by how often the code
                    # occurs as a negative
                    individual_probs = np.array([
                        (counts.positive[c] / counts.positive.sum())
                        * (1 / counts.negative[c])
                        for c in negatives
                    ])
                    individual_probs = individual_probs / individual_probs.sum()
                    negatives = list(
                        np.random.choice(negatives,
                                         size=len(negatives),
                                         p=individual_probs))
                    # a Binomial(n, p) draw decides how many of the
                    # resampled negatives to keep
                    keep = np.random.binomial(len(negatives),
                                              resample_neg_proportion)
                    negatives = negatives[:keep]
                if ancestors or resample_neg_proportion is not None:
                    targets, labels = get_targets_labels(positives, negatives)
                self.datapoint['labels'] = torch.tensor(labels)
            self.datapoint['codes'] = torch.tensor(
                [codes[code_str] for code_str in targets])
            self.datapoint['num_codes'] = torch.tensor(
                self.datapoint['codes'].size(0))
            self.observed += ['num_codes']
        elif 'labels' in raw_datapoint.keys():
            raise Exception('labels were given without targets')

        # get observed
        if code_id:
            # needs targets
            if 'targets' not in raw_datapoint.keys():
                raise Exception('code_id requires targets')
            self.observed += ['codes']
        if code_description:
            # get description
            # doesn't need targets as long as it has queries
            if 'targets' in raw_datapoint.keys():
                descriptions = [get_description_string(t, hierarchy)
                                for t in targets]
            else:
                descriptions = raw_datapoint['descriptions']
                # if targets were not given, you still need num_codes
                self.datapoint['num_codes'] = torch.tensor(len(descriptions))
                self.observed += ['num_codes']
            tokenized_descriptions = [
                tokenizer.tokenize(d) for d in descriptions
            ]
            if not tfidf_tokenizer:
                self.datapoint['code_description'] = pad_and_concat([
                    torch.tensor(tokenizer.convert_tokens_to_ids(d))
                    for d in tokenized_descriptions
                ])
                self.datapoint['code_description_length'] = torch.tensor(
                    [len(d) for d in tokenized_descriptions])
                self.observed += [
                    'code_description', 'code_description_length'
                ]
        if code_linearization:
            # get code linearization
            # needs targets
            if 'targets' not in raw_datapoint.keys():
                raise Exception('code_linearization requires targets')
            linearized_codes = [
                hierarchy.linearize(target) for target in targets
            ]
            self.datapoint['linearized_codes'] = pad_and_concat([
                torch.tensor(linearized_code)
                for linearized_code in linearized_codes
            ])
            self.datapoint['linearized_codes_lengths'] = torch.tensor(
                [len(linearized_code) for linearized_code in linearized_codes])
            self.observed += ['linearized_codes', 'linearized_codes_lengths']
        if description_linearization:
            # get description
            # doesn't need targets as long as it has queries
            if 'targets' in raw_datapoint.keys():
                descriptions = [
                    get_description_linearization(t, hierarchy)
                    for t in targets
                ]
            else:
                descriptions = raw_datapoint['description_linearizations']
                # if targets were not given, you still need num_codes
                self.datapoint['num_codes'] = torch.tensor(len(descriptions))
                self.observed += ['num_codes']
            tokenized_description_linearizations = [
                tokenizer.tokenize(d) for d in descriptions
            ]
            if not tfidf_tokenizer:
                self.datapoint['linearized_descriptions'] = pad_and_concat([
                    torch.tensor(tokenizer.convert_tokens_to_ids(d))
                    for d in tokenized_description_linearizations
                ])
                self.datapoint['linearized_descriptions_lengths'] = \
                    torch.tensor(
                        [len(d) for d in tokenized_description_linearizations])
                self.observed += [
                    'linearized_descriptions',
                    'linearized_descriptions_lengths'
                ]
        if description_embedding_linearization:
            # get description
            # doesn't need targets as long as it has queries
            if 'targets' in raw_datapoint.keys():
                descriptions = [
                    get_description_embedding_linearization(t, hierarchy)
                    for t in targets
                ]
            else:
                # interface doesn't produce valid queries for this yet;
                # the lines below are an unreachable placeholder
                raise NotImplementedError
                descriptions = raw_datapoint['queries']
                # if targets were not given, you still need num_codes
                self.datapoint['num_codes'] = torch.tensor(len(descriptions))
                self.observed += ['num_codes']
            tokenized_description_embedding_linearizations = [
                tokenizer.tokenize(d) for d in descriptions
            ]
            if not tfidf_tokenizer:
                self.datapoint['linearized_description_embeddings'] = \
                    pad_and_concat([
                        torch.tensor(tokenizer.convert_tokens_to_ids(d))
                        for d in tokenized_description_embedding_linearizations
                    ])
                self.datapoint['linearized_description_embeddings_lengths'] = \
                    torch.tensor([
                        len(d) for d in
                        tokenized_description_embedding_linearizations
                    ])
                self.observed += [
                    'linearized_description_embeddings',
                    'linearized_description_embeddings_lengths'
                ]

        if tfidf_tokenizer:
            if code_description:
                # vectorize sentences and descriptions together (presumably
                # so they share tf-idf statistics), then split the matrix
                # back apart
                tfidf_matrix = tokenizer.convert_tokens_to_ids(
                    self.tokenized_sentences + tokenized_descriptions)
                self.datapoint['code_description'] = torch.tensor(
                    tfidf_matrix[len(self.tokenized_sentences):])
                self.observed += ['code_description']
                tfidf_matrix = tfidf_matrix[:len(self.tokenized_sentences)]
            else:
                tfidf_matrix = tokenizer.convert_tokens_to_ids(
                    self.tokenized_sentences)
            self.datapoint['article_sentences'] = torch.tensor(tfidf_matrix)
            self.datapoint['num_sentences'] = torch.tensor(
                len(self.tokenized_sentences))
            self.observed += ['article_sentences', 'num_sentences']
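
The negative-resampling step above is easier to see in isolation. A self-contained demo of the same scheme, with plain dicts standing in for the counts object: each negative is drawn with probability proportional to its positive prior divided by its negative count, and a Binomial draw decides how many resampled negatives to keep.

import numpy as np

positive = {'A': 90, 'B': 9, 'C': 1}     # stand-in for counts.positive
negative = {'A': 10, 'B': 50, 'C': 100}  # stand-in for counts.negative
negatives = ['A', 'B', 'C']
resample_neg_proportion = 0.5

pos_total = sum(positive.values())
probs = np.array([(positive[c] / pos_total) * (1 / negative[c])
                  for c in negatives])
probs = probs / probs.sum()

resampled = list(np.random.choice(negatives, size=len(negatives), p=probs))
keep = np.random.binomial(len(resampled), resample_neg_proportion)
print(resampled[:keep])  # 'A' dominates: it has by far the largest positive prior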