def __init__(self,
                 json,
                 text_encoder=None,
                 label_encoder=None,
                 vocab=None,
                 mode='train'):
        '''
        Initialization
        Arguments:
        json: parsed JSON object (dict) containing the data.
            Structure of the json object, e.g.:
                 json: {'data': [{'id': filename,
                                  'title': title of the page,
                                  'toc': [list of items in the table-of-contents section of the wiki page],
                                  'intro': introduction of the wiki page,
                                  'label': 'positive'/'negative' flag}]
                        }
            Labels are required only when mode == 'train'.
        text_encoder: encoder object that encodes tokens to their unique integer ids
        label_encoder: encoder object that encodes labels to their unique integer ids
        vocab: external vocabulary used to initialize the text encoder. If vocab is None, it is built from the tokens of the provided dataset.
        mode: 'train' or 'inference'; when mode == 'inference', the dataset object skips the labels.
        '''
        self.data = json
        assert 'data' in self.data

        # Define the mode in which the dataset object is to be used
        self.mode = mode

        # Define text encoder and vocabulary
        if text_encoder:
            self._text_encoder = text_encoder
            self._vocab = self._text_encoder.vocab
        elif vocab:
            self._vocab = vocab
            self._text_encoder = StaticTokenizerEncoder(self._vocab,
                                                        append_eos=False,
                                                        tokenize=self.split)
        else:
            self._vocab = self.create_vocab()
            self._text_encoder = StaticTokenizerEncoder(self._vocab,
                                                        append_eos=False,
                                                        tokenize=self.split)

        self._vocab_size = self._text_encoder.vocab_size

        # Define label encoder
        if self.mode == 'train':
            if label_encoder:
                self._label_encoder = label_encoder
            else:
                self._label_encoder = LabelEncoder(
                    [sample['label'] for sample in self.data['data']])

            self._label_size = self._label_encoder.vocab_size

        else:
            self._label_encoder = None
            self._label_size = None
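# --- Usage sketch (added illustration, not part of the original example). The
# dataset's class name is not shown above, so this only demonstrates how the two
# encoders it expects can be built directly from a toy JSON payload with torchnlp.
from torchnlp.encoders import LabelEncoder
from torchnlp.encoders.text import StaticTokenizerEncoder

toy_json = {'data': [
    {'id': 'page_1', 'title': 'PyTorch', 'toc': ['History'], 'intro': 'an open source machine learning library', 'label': 'positive'},
    {'id': 'page_2', 'title': 'Lorem', 'toc': [], 'intro': 'placeholder text used in publishing', 'label': 'negative'},
]}
text_encoder = StaticTokenizerEncoder([s['intro'] for s in toy_json['data']], append_eos=False)
label_encoder = LabelEncoder([s['label'] for s in toy_json['data']])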
    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        self.bert = AutoModel.from_pretrained(self.hparams.encoder_model,
                                              output_hidden_states=True,
                                              num_labels=11)

        # set the number of features our encoder model will return...
        if self.hparams.encoder_model == "google/bert_uncased_L-2_H-128_A-2":
            self.encoder_features = 128
        else:
            self.encoder_features = 768

        # Tokenizer
        self.tokenizer = BERTTextEncoder("bert-base-uncased")

        # Label Encoder
        self.label_encoder = LabelEncoder(self.hparams.label_set.split(","),
                                          reserved_labels=[])
        self.label_encoder.unknown_index = None

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(self.encoder_features, self.encoder_features * 2),
            nn.Tanh(),
            nn.Linear(self.encoder_features * 2, self.encoder_features),
            nn.Tanh(),
            nn.Linear(self.encoder_features, self.label_encoder.vocab_size),
        )
def encode_ner_y(y_ner_list_train, y_ner_list_test, CLASS_COUNT_DICT):
    y_ner_encoder = LabelEncoder(sample=CLASS_COUNT_DICT.keys())
    y_ner_encoded_train = [[
        y_ner_encoder.encode(label) for label in label_list
    ] for label_list in y_ner_list_train]
    y_ner_encoded_train = [torch.stack(tens) for tens in y_ner_encoded_train]
    y_ner_padded_train = torch.LongTensor(
        pad_sequence(y_ner_encoded_train, MAX_SENTENCE_LEN + 1))

    y_ner_encoded_test = [[
        y_ner_encoder.encode(label) for label in label_list
    ] for label_list in y_ner_list_test]
    y_ner_encoded_test = [torch.stack(tens) for tens in y_ner_encoded_test]
    y_ner_padded_test = torch.LongTensor(
        pad_sequence(y_ner_encoded_test, MAX_SENTENCE_LEN + 1))

    if y_ner_padded_train.shape[1] > y_ner_padded_test.shape[1]:
        y_ner_padded_test = torch.cat(
            (
                y_ner_padded_test,
                torch.zeros(
                    y_ner_padded_test.shape[0],
                    y_ner_padded_train.shape[1] - y_ner_padded_test.shape[1],
                ),
            ),
            dim=1,
        ).type(torch.long)

    return y_ner_padded_train, y_ner_padded_test
Example #4
    def prepare_encoder(self, dataset, parser):
        words = get_words(dataset, tokenize)
        chars = get_characters(dataset, tokenize)
        samples = get_samples(dataset, parser)

        qencoder = LabelEncoder(words, 2)
        cencoder = LabelEncoder(chars, 0)
        aencoder = ActionSequenceEncoder(samples, 2)
        return qencoder, cencoder, aencoder
def test_label_encoder_known(label_encoder):
    input_ = 'symbols/namesake/named_after'
    sample = [
        'people/deceased_person/place_of_death',
        'symbols/name_source/namesakes'
    ]
    sample.append(input_)
    label_encoder = LabelEncoder(sample)
    output = label_encoder.encode(input_)
    assert label_encoder.decode(output) == input_
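# Added sketch (not part of the original tests): the same round trip also holds
# for batches via ``batch_encode`` / ``batch_decode``.
def test_label_encoder_batch_roundtrip():
    sample = [
        'people/deceased_person/place_of_death',
        'symbols/name_source/namesakes'
    ]
    label_encoder = LabelEncoder(sample)
    encoded = label_encoder.batch_encode(sample)
    assert label_encoder.batch_decode(encoded) == sample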
Example #6
def convert_examples_to_features(
        examples: List[InputExample],
        tokenizer: PreTrainedTokenizer,
        slot_label_encoder: LabelEncoder,
        intent_label_encoder: LabelEncoder,
        max_length: Optional[int] = None
):
    if max_length is None:
        max_length = tokenizer.max_len

    # Generate slot / intent label ids.
    truncation_size = max_length // 2
    features = []
    for example in examples:
        tmp_label_slot_origin = slot_label_encoder.batch_encode(example.label_slot_raw).numpy().tolist()
        if example.label_intent_raw:
            tmp_label_intent_origin = intent_label_encoder.batch_encode(example.label_intent_raw).numpy().tolist()
        else:
            tmp_label_intent_origin = []
        # Manually truncate query/context, each to at most `max_length // 2`
        tmp_text_query = example.text_query[:truncation_size]  # truncate the query from the back (keep the head)
        tmp_label_slot_origin = tmp_label_slot_origin[:truncation_size]
        tmp_text_context = example.text_context[-truncation_size:]  # truncate the context from the front (keep the tail)
        # TODO: optionally swap the query/context order
        tokenizer_outputs = tokenizer.encode_plus(list(tmp_text_query), list(tmp_text_context), padding='max_length',
                                                  max_length=max_length)
        # Build the slot mask
        tmp_slot_mask = np.asarray([0] * max_length)
        tmp_slot_mask[1:len(tmp_text_query) + 1] = 1
        # tmp_slot_mask[len(tmp_text_context)+2: len(tmp_text_context)+len(tmp_text_query)+2] = 1
        tmp_slot_mask = list(tmp_slot_mask)

        # Build the slot labels (tmp_label_slot)
        tmp_label_slot = [0] * max_length
        tmp_label_slot[1:len(tmp_text_query) + 1] = tmp_label_slot_origin

        # Build the intent labels (multi-hot tmp_label_intent)
        tmp_label_intent = np.asarray([0] * intent_label_encoder.vocab_size)
        tmp_label_intent[tmp_label_intent_origin] = 1
        tmp_label_intent = list(tmp_label_intent)

        feature = InputFeatures(**tokenizer_outputs, slot_mask=tmp_slot_mask, label_slot=tmp_label_slot,
                                label_intent=tmp_label_intent)
        features.append(feature)

    for i, example in enumerate(examples[:5]):
        logger.info("*** Example ***")
        logger.info("guid: %s" % example.guid)
        logger.info("features: %s" % features[i])
    return features
Example #7
    def from_file(self, fasta_file, reduce_to_binary=False):
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq(reduce_to_binary):
            nonlocal cur_seq_label, buf
            if cur_seq_label is None:
                return
            if reduce_to_binary:
                if cur_seq_label == "0":
                    sequence_labels.append(cur_seq_label)
                else:
                    sequence_labels.append("1")
            else:
                sequence_labels.append(cur_seq_label)
            sequence_strs.append("".join(buf))
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):  # label line
                    _flush_current_seq(reduce_to_binary)
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line.split("|")[-1]
                    else:
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:  # sequence line
                    buf.append(" ".join(line.strip()))

        _flush_current_seq(reduce_to_binary)

        # More useful: check that we have an equal number of sequences and labels
        assert len(sequence_strs) == len(sequence_labels)

        # Create label encoder from unique strings in the label list
        self.label_encoder = LabelEncoder(np.unique(sequence_labels),
                                          reserved_labels=[])

        self.data = []
        for i in range(len(sequence_strs)):
            self.data.append({
                "seq": str(sequence_strs[i]),
                "label": str(sequence_labels[i])
            })

        return self.data
Example #8
 def get_label_encoder(self):
     label_encoder = {}
     for task in self.task_list:
         label_encoder[task] = LabelEncoder([*self.task_labels[task]],
                                            reserved_labels=['unknown'],
                                            unknown_index=0)
     return label_encoder
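# Added sketch (hypothetical task setup, not from the original class): what the
# per-task encoder dict built above looks like for two toy tasks; with a reserved
# 'unknown' label at index 0, unseen labels fall back to that index.
task_labels = {'sentiment': ['pos', 'neg'], 'topic': ['sports', 'politics']}
label_encoder = {
    task: LabelEncoder(list(labels), reserved_labels=['unknown'], unknown_index=0)
    for task, labels in task_labels.items()
}
assert label_encoder['sentiment'].encode('surprise').item() == 0  # maps to 'unknown'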
Example #9
        def __init__(self, classifier_instance):
            super().__init__()
            self.hparams = classifier_instance.hparams

            if self.hparams.transformer_type == 'longformer':
                self.hparams.batch_size = 1
            self.classifier = classifier_instance

            self.transformer_type = self.hparams.transformer_type

            self.n_labels = 50
            self.top_codes = pd.read_csv(
                self.hparams.train_csv)['ICD9_CODE'].value_counts(
                )[:self.n_labels].index.tolist()
            logger.warning(
                f'Classifying against the top {self.n_labels} most frequent ICD codes: {self.top_codes}'
            )

            # Label Encoder
            if self.hparams.single_label_encoding == 'default':
                self.label_encoder = LabelEncoder(np.unique(
                    self.top_codes).tolist(),
                                                  reserved_labels=[])

            self.label_encoder.unknown_index = None
Example #10
 def __init__(self, classifier_instance):
     super().__init__()
     self.hparams = classifier_instance.hparams
     self.classifier = classifier_instance
     # Label Encoder
     self.label_encoder = LabelEncoder(pd.read_csv(
         self.hparams.train_csv).label.unique().tolist(),
                                       reserved_labels=[])
     self.label_encoder.unknown_index = None
    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        self.bert = BertModel.from_pretrained(
            "bert-base-uncased", output_hidden_states=True
        )

        # Tokenizer
        self.tokenizer = BERTTextEncoder("bert-base-uncased")

        # Label Encoder
        self.label_set = {"neg": 0, "pos": 1}
        self.label_encoder = LabelEncoder(
            list(self.label_set.keys()), reserved_labels=[]
        )

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Dropout(self.hparams.dropout),
            nn.Linear(768, self.label_encoder.vocab_size),
        )
Example #12
def encode_ner_y(y_ner_list_train, y_ner_list_test, class_count_dict, max_sent_len):
    """
    Tokenize y
    :param y_ner_list_train:
    :param y_ner_list_test:
    :param class_count_dict:
    :param max_sent_len:
    :return:
    """
    y_ner_encoder = LabelEncoder(sample=class_count_dict.keys())
    y_ner_encoded_train = [
        [y_ner_encoder.encode(label) for label in label_list]
        for label_list in y_ner_list_train
    ]
    y_ner_encoded_train = [torch.stack(tens) for tens in y_ner_encoded_train]
    y_ner_padded_train = torch.LongTensor(
        pad_sequence(y_ner_encoded_train, max_sent_len + 1)
    )

    y_ner_encoded_test = [
        [y_ner_encoder.encode(label) for label in label_list]
        for label_list in y_ner_list_test
    ]
    y_ner_encoded_test = [torch.stack(tens) for tens in y_ner_encoded_test]
    y_ner_padded_test = torch.LongTensor(
        pad_sequence(y_ner_encoded_test, max_sent_len + 1)
    )

    if y_ner_padded_train.shape[1] > y_ner_padded_test.shape[1]:
        y_ner_padded_test = torch.cat(
            (
                y_ner_padded_test,
                torch.zeros(
                    y_ner_padded_test.shape[0],
                    y_ner_padded_train.shape[1] - y_ner_padded_test.shape[1],
                ),
            ),
            dim=1,
        ).type(torch.long)

    return y_ner_encoder, y_ner_padded_train, y_ner_padded_test
Example #13
    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        self.bert = BertModel.from_pretrained("bert-base-uncased",
                                              output_hidden_states=True)

        # Tokenizer
        self.tokenizer = BERTTextEncoder("bert-base-uncased")

        # Label Encoder
        self.label_encoder = LabelEncoder(self.hparams.label_set.split(","),
                                          reserved_labels=[])
        self.label_encoder.unknown_index = None

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(768, 1536),
            nn.Tanh(),
            nn.Linear(1536, 768),
            nn.Tanh(),
            nn.Linear(768, self.label_encoder.vocab_size),
        )
Example #14
 def _create_examples(self, samples, set_type):
     """Creates examples for the training, dev and test sets."""
     examples = []
     if set_type == 'train':
         tmp_intent_counter = defaultdict(int)
         slot_label_set = set()
         intent_label_set = set()
     else:
         tmp_intent_counter = None
         slot_label_set = None
         intent_label_set = None
     for i, sample in enumerate(samples):
         guid = f"{set_type}-{i}"
         assert len(sample[0]) == len(sample[1]), f"Error slot label length occurred at {i} in {set_type} dataset."
         tmp_query_char_list = []
         label_slot_raw = []
         for tmp_char, tmp_slot_label in zip(sample[0], sample[1]):
             # skip whitespace characters
             if is_whitespace(tmp_char):
                 continue
             tmp_query_char_list.append(tmp_char)
             label_slot_raw.append(tmp_slot_label)
         text_query = ''.join(tmp_query_char_list).lower()  # lowercase
         text_context = sample[4][0].lower()  # lowercase; only a single context is supported for now
         label_intent_raw = sample[2]
         if set_type == 'train':
             slot_label_set.update(label_slot_raw)
             intent_label_set.update(label_intent_raw)
             for tmp_item in label_intent_raw:
                 tmp_intent_counter[tmp_item] += 1
         examples.append(InputExample(guid=guid, text_query=text_query, text_context=text_context,
                                      label_slot_raw=label_slot_raw,
                                      label_intent_raw=label_intent_raw))
     if set_type == 'train':
         self.intent_counter = tmp_intent_counter
         self.slot_label_encoder = LabelEncoder(list(slot_label_set), reserved_labels=[], unknown_index=None)
         self.intent_label_encoder = LabelEncoder(list(intent_label_set), reserved_labels=[], unknown_index=None)
     return examples
Example #15
    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        self.bert = AutoModel.from_pretrained(
            "google/bert_uncased_L-2_H-128_A-2", output_hidden_states=True
        )

        # Tokenizer
        self.tokenizer = BERTTextEncoder("google/bert_uncased_L-2_H-128_A-2")

        # Label Encoder
        self.label_encoder = LabelEncoder(
            self.hparams.label_set.split(";"), reserved_labels=[]
        )
        self.label_encoder.unknown_index = None
        
        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, self.label_encoder.vocab_size),
        )
Example #16
        def __init__(self, classifier_instance):
            super().__init__()
            self.hparams = classifier_instance.hparams

            if self.hparams.transformer_type == 'longformer':
                self.hparams.batch_size = 1
            self.classifier = classifier_instance

            self.transformer_type = self.hparams.transformer_type

            # Label Encoder
            self.label_encoder = LabelEncoder(
                ['contradiction', 'entailment', 'neutral'], reserved_labels=[])

            self.label_encoder.unknown_index = None
Example #17
 def __init__(self, classifier, data_split = None, **kwargs):
     super().__init__()
     self.hparams = classifier.hparams
     self.classifier = classifier
     if data_split is None:
         # this happens when loading a checkpoint
         data_split = (None, None, None)
     elif len(data_split) == 2:
         # add empty validation set
         tr_data, val_data = self.split(data_split[0], 0.9)
         data_split = (tr_data, val_data, data_split[1])
     self.data_split = data_split
     self.kwargs = kwargs
     self.label_encoder = LabelEncoder(
         ['pos', 'neg'],
         reserved_labels = [],
     )
def test_label_encoder_no_reserved():
    sample = [
        'people/deceased_person/place_of_death',
        'symbols/name_source/namesakes'
    ]
    label_encoder = LabelEncoder(sample,
                                 reserved_labels=[],
                                 unknown_index=None)

    label_encoder.encode('people/deceased_person/place_of_death')

    # With no ``unknown_index`` defined, encoding an unknown label raises a ``TypeError``.
    with pytest.raises(TypeError):
        label_encoder.encode('symbols/namesake/named_after')
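# Added counterpart sketch (not part of the original tests): with a reserved
# 'unknown' label and ``unknown_index=0``, unseen labels are mapped to that
# index instead of raising.
def test_label_encoder_with_unknown():
    sample = [
        'people/deceased_person/place_of_death',
        'symbols/name_source/namesakes'
    ]
    label_encoder = LabelEncoder(sample,
                                 reserved_labels=['unknown'],
                                 unknown_index=0)
    encoded = label_encoder.encode('symbols/namesake/named_after')
    assert label_encoder.decode(encoded) == 'unknown'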
def main():
    arg_parser = argparse.ArgumentParser("Minimalist Transformer")
    arg_parser.add_argument("mode", choices=["preprocess", "train"],
                    help="train a model or test or translate")
    arg_parser.add_argument('-f', '--config', default='default.yaml', 
                    help='Configuration file to load.')
    args = arg_parser.parse_args()
    configs = yaml.load(open(args.config).read(), Loader=yaml.FullLoader)

    if args.mode == "train":
        train_manager(configs)

    elif args.mode == "preprocess":
        train, _, test =  twitter_airline_dataset()
        text_encoder = MosesEncoder(train['source'])
        label_encoder = LabelEncoder(train['target'])
        with open('.preprocess.pkl', 'wb') as filehandler:
            pickle.dump((text_encoder, label_encoder, train, test), filehandler)

    else:
        raise ValueError("Unknown mode")
def prepare_sample(
    sample: dict, text_encoder: WhitespaceEncoder, label_encoder: LabelEncoder,
    max_length: int
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Function that receives a sample from the Dataset iterator and prepares the
    input to feed the transformer model.
    :param sample: list of dictionaries containing the inputs to build the batch
        (e.g: [{'source': 'This flight was amazing!', 'target': 'pos'}, 
               {'source': 'I hate Iberia', 'target': 'neg'}])
    :param text_encoder: Torch NLP text encoder for tokenization and vectorization.
    :param label_encoder: Torch NLP label encoder for vectorization of labels.
    :param max_length: Max length of the input sequences.
         If a sequence passes that value it is truncated.
    """
    sample = collate_tensors(sample)
    input_seqs, input_lengths = text_encoder.batch_encode(sample['source'])
    target_seqs = label_encoder.batch_encode(sample['target'])
    # Truncate Inputs
    if input_seqs.size(1) > max_length:
        input_seqs = input_seqs[:, :max_length]
    input_mask = lengths_to_mask(input_lengths).unsqueeze(1)
    return input_seqs, input_mask, target_seqs
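# Usage sketch (added illustration; toy batch, torchnlp >= 0.5 assumed): build the
# two encoders from the batch itself and run it through prepare_sample above.
from torchnlp.encoders import LabelEncoder
from torchnlp.encoders.text import WhitespaceEncoder

batch = [
    {'source': 'This flight was amazing!', 'target': 'pos'},
    {'source': 'I hate Iberia', 'target': 'neg'},
]
text_encoder = WhitespaceEncoder([row['source'] for row in batch])
label_encoder = LabelEncoder([row['target'] for row in batch])
input_seqs, input_mask, target_seqs = prepare_sample(
    batch, text_encoder, label_encoder, max_length=32)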
Example #21
    def __build_model(self) -> None:
        self.config = AutoConfig.from_pretrained(self.hparams.encoder_model,
                                                 output_hidden_states=True)
        self.bert = AutoModel.from_pretrained(self.hparams.encoder_model,
                                              config=self.config)

        if self.hparams.encoder_model == "google/bert_uncased_L-2_H-128_A-2":
            self.encoder_features = 128
        else:
            self.encoder_features = 768

        self.tokenizer = BERTTextEncoder("bert-base-uncased")

        self.label_encoder = LabelEncoder(self.hparams.label_set.split(","),
                                          reserved_labels=[])
        self.label_encoder.unknown_index = None

        self.classification_head = nn.Sequential(
            nn.Linear(self.encoder_features, self.encoder_features * 2),
            nn.Tanh(),
            nn.Linear(self.encoder_features * 2, self.encoder_features),
            nn.Tanh(),
            nn.Linear(self.encoder_features, self.label_encoder.vocab_size),
        )
Example #22
 def __init__(self, samples: Samples, token_threshold: int):
     reserved_labels: List[Union[Unknown,
                                 CloseVariadicFieldRule]] = [Unknown()]
     reserved_labels.append(CloseVariadicFieldRule())
     self._rule_encoder = LabelEncoder(samples.rules,
                                       reserved_labels=reserved_labels,
                                       unknown_index=0)
     self._node_type_encoder = LabelEncoder(samples.node_types)
     reserved_labels = [Unknown()]
     self._token_encoder = LabelEncoder(samples.tokens,
                                        min_occurrences=token_threshold,
                                        reserved_labels=reserved_labels,
                                        unknown_index=0)
     self.value_to_idx: Dict[str, List[int]] = {}
     for kind, value in self._token_encoder.vocab[len(reserved_labels):]:
         idx = self._token_encoder.encode((kind, value))
         if value not in self.value_to_idx:
             self.value_to_idx[value] = []
         self.value_to_idx[value].append(idx)
Example #23
 def _build_model(self):
     self.label_encoder = LabelEncoder(self.hparams.tag_set.split(","),
                                       reserved_labels=[])
Example #24
class Tagger(CaptionModelBase):
    """
    Tagger base class.

    :param hparams: HyperOptArgumentParser containing the hyperparameters.
    """
    def __init__(
        self,
        hparams: HyperOptArgumentParser,
    ) -> None:
        super().__init__(hparams)

    def _build_model(self):
        self.label_encoder = LabelEncoder(self.hparams.tag_set.split(","),
                                          reserved_labels=[])

    def _build_loss(self):
        """ Initializes the loss function/s. """
        weights = (np.array([
            float(x) for x in self.hparams.class_weights.split(",")
        ]) if self.hparams.class_weights != "ignore" else np.array([]))

        if self.hparams.loss == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(
                reduction="sum",
                ignore_index=self.label_encoder.vocab_size,
                weight=torch.tensor(weights, dtype=torch.float32)
                if weights.any() else None,
            )
        else:
            raise Exception(f"{self.hparams.loss} is not a valid loss option.")

    def _retrieve_dataset(self, data_hparams, train=True, val=True, test=True):
        """ Retrieves task specific dataset """
        return sequence_tagging_dataset(data_hparams, train, val, test)

    @property
    def default_slot_index(self):
        """ Index of the default slot to be ignored. (e.g. 'O' in 'B-I-O' tags) """
        return 0

    def predict(self, sample: dict) -> list:
        """ Function that runs a model prediction,
        :param sample: a dictionary that must contain the the 'source' sequence.

        Return: list with predictions
        """
        if self.training:
            self.eval()

        return_dict = False
        if isinstance(sample, dict):
            sample = [sample]
            return_dict = True

        with torch.no_grad():
            model_input, _ = self.prepare_sample(sample, prepare_target=False)
            model_out = self.forward(**model_input)
            tag_logits = model_out["tags"]
            _, pred_labels = tag_logits.topk(1, dim=-1)

            for i in range(pred_labels.size(0)):
                sample_tags = pred_labels[i, :, :].view(-1)
                tags = [
                    self.label_encoder.index_to_token[sample_tags[j]]
                    for j in range(model_input["word_lengths"][i])
                ]
                sample[i]["predicted_tags"] = " ".join(tags)
                sample[i]["tagged_sequence"] = " ".join([
                    word + "/" + tag
                    for word, tag in zip(sample[i]["text"].split(), tags)
                ])

                sample[i][
                    "encoded_ground_truth_tags"] = self.label_encoder.batch_encode(
                        [tag for tag in sample[i]["tags"].split()])

                if self.hparams.ignore_last_tag:
                    if (sample[i]["encoded_ground_truth_tags"][
                            model_input["word_lengths"][i] - 1] == 1):
                        sample[i]["encoded_ground_truth_tags"][
                            model_input["word_lengths"][i] -
                            1] = self.label_encoder.vocab_size

        if return_dict:
            return sample[0]

        return sample

    def _compute_loss(self, model_out: dict, targets: dict) -> torch.tensor:
        """
        Computes Loss value according to a loss function.
        :param model_out: model specific output with predicted tag logits
            a tensor [batch_size x seq_length x num_tags]
        :param targets: Target tags [batch_size x seq_length]
        """
        logits = model_out["tags"].view(-1, model_out["tags"].size(-1))
        labels = targets["tags"].view(-1)
        return self.loss(logits, labels)

    def prepare_sample(self,
                       sample: list,
                       prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.
        
        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target values.
        """
        sample = collate_tensors(sample)
        inputs = self.encoder.prepare_sample(sample["text"], trackpos=True)
        if not prepare_target:
            return inputs, {}

        tags, _ = stack_and_pad_tensors(
            [
                self.label_encoder.batch_encode(tags.split())
                for tags in sample["tags"]
            ],
            padding_index=self.label_encoder.vocab_size,
        )

        if self.hparams.ignore_first_title:
            first_tokens = tags[:, 0].clone()
            tags[:, 0] = first_tokens.masked_fill_(
                first_tokens == self.label_encoder.token_to_index["T"],
                self.label_encoder.vocab_size,
            )

        # TODO is this still needed ?
        if self.hparams.ignore_last_tag:
            lengths = [len(tags.split()) for tags in sample["tags"]]
            lengths = np.asarray(lengths)
            k = 0
            for length in lengths:
                if tags[k][length - 1] == 1:
                    tags[k][length - 1] = self.label_encoder.vocab_size
                k += 1

        targets = {"tags": tags}
        return inputs, targets

    def _compute_metrics(self, outputs: list) -> dict:
        """ 
        Private function that computes metrics of interest based on model predictions 
        and respective targets.
        """
        predictions = [
            batch_out["val_prediction"]["tags"] for batch_out in outputs
        ]
        targets = [batch_out["val_target"]["tags"] for batch_out in outputs]

        predicted_tags, ground_truth = [], []
        for i in range(len(predictions)):
            # Get logits and reshape predictions
            batch_predictions = predictions[i]
            logits = batch_predictions.view(-1,
                                            batch_predictions.size(-1)).cpu()
            _, pred_labels = logits.topk(1, dim=-1)

            # Reshape targets
            batch_targets = targets[i].view(-1).cpu()

            assert batch_targets.size() == pred_labels.view(-1).size()
            ground_truth.append(batch_targets)
            predicted_tags.append(pred_labels.view(-1))

        return classification_report(
            torch.cat(predicted_tags).numpy(),
            torch.cat(ground_truth).numpy(),
            padding=self.label_encoder.vocab_size,
            labels=self.label_encoder.token_to_index,
            ignore=self.default_slot_index,
        )

    @classmethod
    def add_model_specific_args(
            cls, parser: HyperOptArgumentParser) -> HyperOptArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters. 
        :param parser: HyperOptArgumentParser obj

        Returns:
            - updated parser
        """
        parser = super(Tagger, Tagger).add_model_specific_args(parser)
        parser.add_argument(
            "--tag_set",
            type=str,
            default="L,U,T",
            help="Task tags we want to use.\
                 Note that the 'default' label should appear first",
        )
        # Loss
        parser.add_argument(
            "--loss",
            default="cross_entropy",
            type=str,
            help="Loss function to be used.",
            choices=["cross_entropy"],
        )
        parser.add_argument(
            "--class_weights",
            default="ignore",
            type=str,
            help=
            'Weights for each of the classes we want to tag (e.g: "1.0,7.0,8.0").',
        )
        ## Data args:
        parser.add_argument(
            "--data_type",
            default="csv",
            type=str,
            help="The type of the file containing the training/dev/test data.",
            choices=["csv"],
        )
        parser.add_argument(
            "--train_path",
            default="data/dummy_train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_path",
            default="data/dummy_test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_path",
            default="data/dummy_test.csv",
            type=str,
            help="Path to the file containing the test data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=0,
            type=int,
            help=("How many subprocesses to use for data loading. 0 means that"
                  "the data will be loaded in the main process."),
        )
        # Metric args:
        parser.add_argument(
            "--ignore_first_title",
            default=False,
            help="When used, this flag ignores T tags in the first position.",
            action="store_true",
        )
        parser.add_argument(
            "--ignore_last_tag",
            default=False,
            help="When used, this flag ignores S tags in the last position.",
            action="store_true",
        )
        return parser
Example #25
def model_load(fn):
    global model, criterion, optimizer
    with open(fn, 'rb') as f:
        model, criterion, optimizer = torch.load(f)


from torchnlp import datasets
from torchnlp.encoders import LabelEncoder
from torchnlp.samplers import BPTTBatchSampler

print('Producing dataset...')
train, val, test = getattr(datasets, args.data)(train=True,
                                                dev=True,
                                                test=True)

encoder = LabelEncoder(train + val + test)

train_data = encoder.batch_encode(train)
val_data = encoder.batch_encode(val)
test_data = encoder.batch_encode(test)

eval_batch_size = 10
test_batch_size = 1

train_source_sampler, val_source_sampler, test_source_sampler = tuple([
    BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'source')
    for d in (train, val, test)
])

train_target_sampler, val_target_sampler, test_target_sampler = tuple([
    BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'target')
    for d in (train, val, test)
])
Example #26
train, dev, test = snli_dataset(train=True, dev=True, test=True)

# Preprocess
for row in itertools.chain(train, dev, test):
    row['premise'] = row['premise'].lower()
    row['hypothesis'] = row['hypothesis'].lower()

# Make Encoders
sentence_corpus = [row['premise'] for row in itertools.chain(train, dev, test)]
sentence_corpus += [
    row['hypothesis'] for row in itertools.chain(train, dev, test)
]
sentence_encoder = WhitespaceEncoder(sentence_corpus)

label_corpus = [row['label'] for row in itertools.chain(train, dev, test)]
label_encoder = LabelEncoder(label_corpus)

# Encode
for row in itertools.chain(train, dev, test):
    row['premise'] = sentence_encoder.encode(row['premise'])
    row['hypothesis'] = sentence_encoder.encode(row['hypothesis'])
    row['label'] = label_encoder.encode(row['label'])

config = args
config.n_embed = sentence_encoder.vocab_size
config.d_out = label_encoder.vocab_size
config.n_cells = config.n_layers

# double the number of cells for bidirectional networks
if config.birnn:
    config.n_cells *= 2
Example #27
class ActionSequenceEncoder:
    def __init__(self, samples: Samples, token_threshold: int):
        reserved_labels: List[Union[Unknown,
                                    CloseVariadicFieldRule]] = [Unknown()]
        reserved_labels.append(CloseVariadicFieldRule())
        self._rule_encoder = LabelEncoder(samples.rules,
                                          reserved_labels=reserved_labels,
                                          unknown_index=0)
        self._node_type_encoder = LabelEncoder(samples.node_types)
        reserved_labels = [Unknown()]
        self._token_encoder = LabelEncoder(samples.tokens,
                                           min_occurrences=token_threshold,
                                           reserved_labels=reserved_labels,
                                           unknown_index=0)
        self.value_to_idx: Dict[str, List[int]] = {}
        for kind, value in self._token_encoder.vocab[len(reserved_labels):]:
            idx = self._token_encoder.encode((kind, value))
            if value not in self.value_to_idx:
                self.value_to_idx[value] = []
            self.value_to_idx[value].append(idx)

    def decode(self, tensor: torch.LongTensor, reference: List[Token]) \
            -> Optional[ActionSequence]:
        """
        Return the action sequence corresponding to the tensor

        Parameters
        ----------
        tensor: torch.LongTensor
            The encoded tensor with the shape of
            (len(action_sequence), 3). Each action will be encoded by the tuple
            of (ID of the applied rule, ID of the inserted token,
            the index of the word copied from the reference).
            The padding value should be -1.
        reference

        Returns
        -------
        Optional[action_sequence]
            The action sequence corresponding to the tensor
            None if the action sequence cannot be generated.
        """

        retval = ActionSequence()
        for i in range(tensor.shape[0]):
            if tensor[i, 0] > 0:
                # ApplyRule
                rule = self._rule_encoder.decode(tensor[i, 0])
                retval.eval(ApplyRule(rule))
            elif tensor[i, 1] > 0:
                # GenerateToken
                kind, value = self._token_encoder.decode(tensor[i, 1])
                retval.eval(GenerateToken(kind, value))
            elif tensor[i, 2] >= 0:
                # GenerateToken (Copy)
                index = int(tensor[i, 2].numpy())
                if index >= len(reference):
                    logger.debug("reference index is out-of-bounds")
                    return None
                token = reference[index]
                retval.eval(GenerateToken(token.kind, token.raw_value))
            else:
                logger.debug("invalid actions")
                return None

        return retval

    def encode_action(self,
                      action_sequence: ActionSequence,
                      reference: List[Token]) \
            -> Optional[torch.Tensor]:
        """
        Return the tensor encoded the action sequence

        Parameters
        ----------
        action_sequence: action_sequence
            The action_sequence containing action sequence to be encoded
        reference

        Returns
        -------
        Optional[torch.Tensor]
            The encoded tensor. The shape of tensor is
            (len(action_sequence) + 1, 4). Each action will be encoded by
            the tuple of (ID of the node types, ID of the applied rule,
            ID of the inserted token, the index of the word copied from
            the reference. The padding value should be -1.
            None if the action sequence cannot be encoded.
        """
        reference_value = [token.raw_value for token in reference]
        action = \
            torch.ones(len(action_sequence.action_sequence) + 1, 4).long() \
            * -1
        for i in range(len(action_sequence.action_sequence)):
            a = action_sequence.action_sequence[i]
            parent = action_sequence.parent(i)
            if parent is not None:
                parent_action = \
                    cast(ApplyRule,
                         action_sequence.action_sequence[parent.action])
                parent_rule = cast(ExpandTreeRule, parent_action.rule)
                action[i, 0] = self._node_type_encoder.encode(
                    parent_rule.children[parent.field][1])

            if isinstance(a, ApplyRule):
                rule = a.rule
                action[i, 1] = self._rule_encoder.encode(rule)
            else:
                encoded_token = \
                    int(self._token_encoder.encode((a.kind, a.value)).numpy())

                if encoded_token != 0:
                    action[i, 2] = encoded_token

                # Unknown token
                if a.value in reference_value:
                    # TODO use kind in reference
                    action[i, 3] = \
                        reference_value.index(cast(str, a.value))

                if encoded_token == 0 and \
                        a.value not in reference_value:
                    logger.debug("cannot encode token")
                    return None

        head = action_sequence.head
        length = len(action_sequence.action_sequence)
        if head is not None:
            head_action = \
                cast(ApplyRule,
                     action_sequence.action_sequence[head.action])
            head_rule = cast(ExpandTreeRule, head_action.rule)
            action[length, 0] = self._node_type_encoder.encode(
                head_rule.children[head.field][1])

        return action

    def encode_raw_value(self, text: str) -> List[int]:
        if text in self.value_to_idx:
            return self.value_to_idx[text]
        else:
            return [self._token_encoder.encode(Unknown()).item()]

    def batch_encode_raw_value(self, texts: List[str]) -> List[List[int]]:
        return [self.encode_raw_value(text) for text in texts]

    def encode_parent(self, action_sequence) -> torch.Tensor:
        """
        Return the tensor encoded the action sequence

        Parameters
        ----------
        action_sequence: action_sequence
            The action_sequence containing action sequence to be encoded

        Returns
        -------
        torch.Tensor
            The encoded tensor. The shape of `action` tensor is
            (len(action_sequence) + 1, 4). Each action will be encoded by
            the tuple of (ID of the parent node types, ID of the
            parent-action's rule, the index of the parent action,
            the index of the field).
            The padding value should be -1.
        """
        parent_tensor = \
            torch.ones(len(action_sequence.action_sequence) + 1, 4).long() \
            * -1

        for i in range(len(action_sequence.action_sequence)):
            parent = action_sequence.parent(i)
            if parent is not None:
                parent_action = \
                    cast(ApplyRule,
                         action_sequence.action_sequence[parent.action])
                parent_rule = cast(ExpandTreeRule, parent_action.rule)
                parent_tensor[i, 0] = \
                    self._node_type_encoder.encode(parent_rule.parent)
                parent_tensor[i, 1] = self._rule_encoder.encode(parent_rule)
                parent_tensor[i, 2] = parent.action
                parent_tensor[i, 3] = parent.field

        head = action_sequence.head
        length = len(action_sequence.action_sequence)
        if head is not None:
            head_action = \
                cast(ApplyRule,
                     action_sequence.action_sequence[head.action])
            head_rule = cast(ExpandTreeRule, head_action.rule)
            parent_tensor[length, 0] = \
                self._node_type_encoder.encode(head_rule.parent)
            parent_tensor[length, 1] = self._rule_encoder.encode(head_rule)
            parent_tensor[length, 2] = head.action
            parent_tensor[length, 3] = head.field

        return parent_tensor

    def encode_tree(self, action_sequence: ActionSequence) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the tensor adjacency matrix of the action sequence

        Parameters
        ----------
        action_sequence: action_sequence
            The action_sequence containing action sequence to be encoded

        Returns
        -------
        depth: torch.Tensor
            The depth of each action. The shape is (len(action_sequence),).
        adjacency_matrix: torch.Tensor
            The encoded tensor. The shape of tensor is
            (len(action_sequence), len(action_sequence)). If i th action is
            a parent of j th action, (i, j) element will be 1. the element
            will be 0 otherwise.
        """
        L = len(action_sequence.action_sequence)
        depth = torch.zeros(L)
        m = torch.zeros(L, L)

        for i in range(L):
            p = action_sequence.parent(i)
            if p is not None:
                depth[i] = depth[p.action] + 1
                m[p.action, i] = 1

        return depth, m

    def encode_each_action(self,
                           action_sequence: ActionSequence,
                           reference: List[Token],
                           max_arity: int) \
            -> torch.Tensor:
        """
        Return the tensor encoding each action

        Parameters
        ----------
        action_sequence: action_sequence
            The action_sequence containing action sequence to be encoded
        reference
        max_arity: int

        Returns
        -------
        torch.Tensor
            The encoded tensor. The shape of tensor is
            (len(action_sequence), max_arity + 1, 3).
            [:, 0, 0] encodes the parent node type. [:, i, 0] encodes
            the node type of (i - 1)-th child node. [:, i, 1] encodes
            the token of (i - 1)-th child node. [:, i, 2] encodes the reference
            index of (i - 1)-th child node.
            The padding value is -1.
        """
        L = len(action_sequence.action_sequence)
        reference_value = [token.raw_value for token in reference]
        retval = torch.ones(L, max_arity + 1, 3).long() * -1
        for i, action in enumerate(action_sequence.action_sequence):
            if isinstance(action, ApplyRule):
                if isinstance(action.rule, ExpandTreeRule):
                    # Encode parent
                    retval[i, 0, 0] = \
                        self._node_type_encoder.encode(action.rule.parent)
                    # Encode children
                    for j, (_, child) in enumerate(
                            action.rule.children[:max_arity]):
                        retval[i, j + 1, 0] = \
                            self._node_type_encoder.encode(child)
            else:
                gentoken: GenerateToken = action
                kind = gentoken.kind
                value = gentoken.value
                encoded_token = \
                    int(self._token_encoder.encode((kind, value)).numpy())

                if encoded_token != 0:
                    retval[i, 1, 1] = encoded_token

                if value in reference_value:
                    # TODO use kind in reference
                    retval[i, 1, 2] = \
                        reference_value.index(cast(str, value))

        return retval

    def encode_path(self, action_sequence: ActionSequence, max_depth: int) \
            -> torch.Tensor:
        """
        Return the tensor encoding the path to each action

        Parameters
        ----------
        action_sequence: action_sequence
            The action_sequence containing action sequence to be encoded
        max_depth: int

        Returns
        -------
        torch.Tensor
            The encoded tensor. The shape of tensor is
            (len(action_sequence), max_depth).
            [i, :] encodes the path from the root node to i-th node.
            Each node is represented by its rule id.
            The padding value is -1.
        """
        L = len(action_sequence.action_sequence)
        retval = torch.ones(L, max_depth).long() * -1
        for i in range(L):
            parent_opt = action_sequence.parent(i)
            if parent_opt is not None:
                p = action_sequence.action_sequence[parent_opt.action]
                if isinstance(p, ApplyRule):
                    retval[i, 0] = self._rule_encoder.encode(p.rule)
                retval[i, 1:] = retval[parent_opt.action, :max_depth - 1]

        return retval
Example #28
class UniProtData():
    def __init__(self, tokenizer, sequence_length):
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length

    def from_file(self, fasta_file, reduce_to_binary=False):
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq(reduce_to_binary):
            nonlocal cur_seq_label, buf
            if cur_seq_label is None:
                return
            if reduce_to_binary:
                if cur_seq_label == "0":
                    sequence_labels.append(cur_seq_label)
                else:
                    sequence_labels.append("1")
            else:
                sequence_labels.append(cur_seq_label)
            sequence_strs.append("".join(buf))
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):  # label line
                    _flush_current_seq(reduce_to_binary)
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line.split("|")[-1]
                    else:
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:  # sequence line
                    buf.append(" ".join(line.strip()))

        _flush_current_seq(reduce_to_binary)

        # More useful: check that we have an equal number of sequences and labels
        assert len(sequence_strs) == len(sequence_labels)

        # Create label encoder from unique strings in the label list
        self.label_encoder = LabelEncoder(np.unique(sequence_labels),
                                          reserved_labels=[])

        self.data = []
        for i in range(len(sequence_strs)):
            self.data.append({
                "seq": str(sequence_strs[i]),
                "label": str(sequence_labels[i])
            })

        return self.data

    def prepare_sample(self,
                       sample: list,
                       prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.

        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        sample = collate_tensors(sample)

        # Tokenize the input; returns a dict with 3 entries:
        #   input_ids: tokenized matrix
        #   token_type_ids: matrix of 0/1 indicating whether an element belongs to seq0 or seq1
        #   attention_mask: matrix of 0/1 indicating whether a token is masked (0) or not (1)
        # Convert to PT tensor
        inputs = self.tokenizer.batch_encode_plus(
            sample["seq"],
            add_special_tokens=True,
            padding=True,
            truncation=True,
            max_length=self.sequence_length,
            return_tensors="pt")

        if prepare_target is False:
            return inputs, {}

        # Prepare target:
        try:
            targets = {
                "labels": self.label_encoder.batch_encode(sample["label"])
            }
            return inputs, targets
        except RuntimeError:
            print(sample["label"])
            raise Exception("Label encoder found an unknown label.")

    def get_dataloader(self, file_path, batch_size, num_worker=4):
        data = self.from_file(file_path)

        data_loader = DataLoader(data,
                                 batch_size=batch_size,
                                 sampler=RandomSampler(data),
                                 collate_fn=self.prepare_sample,
                                 num_workers=num_worker)

        return data_loader
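# Usage sketch (added illustration): run a tiny in-memory FASTA through
# ``UniProtData.from_file``. The label is taken from the text after the last '|'
# in each header line; the tokenizer is only needed later by ``prepare_sample``.
import tempfile

toy_fasta = ">sp|Q00001|positive\nMKVLAA\n>sp|Q00002|negative\nGGHHTT\n"
with tempfile.NamedTemporaryFile("w", suffix=".fasta", delete=False) as handle:
    handle.write(toy_fasta)

uniprot = UniProtData(tokenizer=None, sequence_length=128)  # tokenizer unused by from_file
records = uniprot.from_file(handle.name)
assert [r["label"] for r in records] == ["positive", "negative"]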
class BERTClassifier(pl.LightningModule):
    """
    Sample model to show how to use BERT to classify sentences.

    :param hparams: ArgumentParser containing the hyperparameters.
    """
    def __init__(self, hparams) -> None:
        super(BERTClassifier, self).__init__()
        self.hparams = hparams
        self.batch_size = hparams.batch_size

        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if hparams.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = hparams.nr_frozen_epochs

    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        self.bert = AutoModel.from_pretrained(self.hparams.encoder_model,
                                              output_hidden_states=True,
                                              num_labels=11)

        # set the number of features our encoder model will return...
        if self.hparams.encoder_model == "google/bert_uncased_L-2_H-128_A-2":
            self.encoder_features = 128
        else:
            self.encoder_features = 768

        # Tokenizer
        self.tokenizer = BERTTextEncoder("bert-base-uncased")

        # Label Encoder
        self.label_encoder = LabelEncoder(self.hparams.label_set.split(","),
                                          reserved_labels=[])
        self.label_encoder.unknown_index = None

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(self.encoder_features, self.encoder_features * 2),
            nn.Tanh(),
            nn.Linear(self.encoder_features * 2, self.encoder_features),
            nn.Tanh(),
            nn.Linear(self.encoder_features, self.label_encoder.vocab_size),
        )

    def __build_loss(self):
        """ Initializes the loss function/s. """
        self._loss = nn.CrossEntropyLoss()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer. """
        if self._frozen:
            log.info(f"\n-- Encoder model fine-tuning")
            for param in self.bert.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.bert.parameters():
            param.requires_grad = False
        self._frozen = True

    def predict(self, sample: dict) -> dict:
        """ Predict function.
        :param sample: dictionary with the text we want to classify.

        Returns:
            Dictionary with the input text and the predicted label.
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            model_input, _ = self.prepare_sample([sample],
                                                 prepare_target=False)
            model_out = self.forward(**model_input)
            logits = model_out["logits"].numpy()
            predicted_labels = [
                self.label_encoder.index_to_token[prediction]
                for prediction in np.argmax(logits, axis=1)
            ]
            sample["predicted_label"] = predicted_labels[0]

        return sample

    def forward(self, tokens, lengths):
        """ Usual pytorch forward function.
        :param tokens: text sequences [batch_size x src_seq_len]
        :param lengths: source lengths [batch_size]

        Returns:
            Dictionary with model outputs (e.g: logits)
        """
        tokens = tokens[:, :lengths.max()]
        # When using just one GPU this should not change behavior
        # but when splitting batches across GPU the tokens have padding
        # from the entire original batch
        mask = lengths_to_mask(lengths, device=tokens.device)
        # Run BERT model.
        word_embeddings = self.bert(tokens, mask)[0]

        # Average Pooling
        word_embeddings = mask_fill(0.0, tokens, word_embeddings,
                                    self.tokenizer.padding_index)
        sentemb = torch.sum(word_embeddings, 1)
        sum_mask = mask.unsqueeze(-1).expand(
            word_embeddings.size()).float().sum(1)
        sentemb = sentemb / sum_mask

        return {"logits": self.classification_head(sentemb)}

    def loss(self, predictions: dict, targets: dict) -> torch.tensor:
        """
        Computes Loss value according to a loss function.
        :param predictions: model specific output. Must contain a key 'logits' with
            a tensor [batch_size x 1] with model predictions
        :param targets: Label values [batch_size]

        Returns:
            torch.tensor with loss value.
        """
        return self._loss(predictions["logits"], targets["labels"])

    def prepare_sample(self,
                       sample: list,
                       prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.

        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        sample = collate_tensors(sample)
        tokens, lengths = self.tokenizer.batch_encode(sample["text"])

        inputs = {"tokens": tokens, "lengths": lengths}

        if not prepare_target:
            return inputs, {}

        # Prepare target:
        try:
            targets = {
                "labels": self.label_encoder.batch_encode(sample["label"])
            }
            return inputs, targets
        except RuntimeError:
            raise Exception("Label encoder found an unknown label.")

    def training_step(self, batch: tuple, batch_nb: int, *args,
                      **kwargs) -> dict:
        """
        Runs one training step. This usually consists of the forward pass followed
            by the loss computation.

        :param batch: The output of your dataloader.
        :param batch_nb: Integer indicating which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        tqdm_dict = {"train_loss": loss_val}
        output = OrderedDict({
            "loss": loss_val,
            "progress_bar": tqdm_dict,
            "log": tqdm_dict
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_step(self, batch: tuple, batch_nb: int, *args,
                        **kwargs) -> dict:
        """ Similar to the training step but with the model in eval mode.

        Returns:
            - dictionary passed to the validation_end function.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        y = targets["labels"]
        y_hat = model_out["logits"]

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        val_acc = torch.tensor(val_acc)

        if self.on_gpu:
            val_acc = val_acc.cuda(loss_val.device.index)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)
            val_acc = val_acc.unsqueeze(0)

        output = OrderedDict({
            "val_loss": loss_val,
            "val_acc": val_acc,
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_end(self, outputs: list) -> dict:
        """ Function that takes as input a list of dictionaries returned by the validation_step
        function and measures the model performance across the entire validation set.

        Returns:
            - Dictionary with metrics to be added to the lightning logger.
        """
        val_loss_mean = 0
        val_acc_mean = 0
        for output in outputs:
            val_loss = output["val_loss"]

            # reduce manually when using dp
            if self.trainer.use_dp or self.trainer.use_ddp2:
                val_loss = torch.mean(val_loss)
            val_loss_mean += val_loss

            # reduce manually when using dp
            val_acc = output["val_acc"]
            if self.trainer.use_dp or self.trainer.use_ddp2:
                val_acc = torch.mean(val_acc)

            val_acc_mean += val_acc

        val_loss_mean /= len(outputs)
        val_acc_mean /= len(outputs)
        tqdm_dict = {"val_loss": val_loss_mean, "val_acc": val_acc_mean}
        result = {
            "progress_bar": tqdm_dict,
            "log": tqdm_dict,
            "val_loss": val_loss_mean,
        }
        return result
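
    # Note: the means above weight every validation batch equally, so a smaller
    # final batch contributes as much as a full-sized one; for reasonably large
    # validation sets the difference is negligible.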

    def configure_optimizers(self):
        """ Sets different Learning rates for different parameter groups. """
        parameters = [
            {
                "params": self.classification_head.parameters()
            },
            {
                "params": self.bert.parameters(),
                "lr": self.hparams.encoder_learning_rate,
            },
        ]
        optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate)
        return [optimizer], []
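
    # With the defaults declared in add_model_specific_args below this gives
    # discriminative learning rates: the classification head trains with
    # --learning_rate (3e-05) while the BERT encoder uses the smaller
    # --encoder_learning_rate (1e-05).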

    def on_epoch_end(self):
        """ Pytorch lightning hook """
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            self.unfreeze_encoder()
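
    # Example schedule (with the default nr_frozen_epochs=1 and assuming the
    # encoder starts out frozen): the encoder stays frozen during epoch 0 and is
    # unfrozen at the end of that epoch, since current_epoch + 1 == 1 >= nr_frozen_epochs.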

    def __retrieve_dataset(self, train=True, val=True, test=True):
        """ Retrieves task specific dataset """
        return sentiment_analysis_dataset(self.hparams, train, val, test)

    @pl.data_loader
    def train_dataloader(self) -> DataLoader:
        """ Function that loads the train set. """
        self._train_dataset = self.__retrieve_dataset(val=False, test=False)[0]
        return DataLoader(
            dataset=self._train_dataset,
            sampler=RandomSampler(self._train_dataset),
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )

    @pl.data_loader
    def val_dataloader(self) -> DataLoader:
        """ Function that loads the validation set. """
        self._dev_dataset = self.__retrieve_dataset(train=False, test=False)[0]
        return DataLoader(
            dataset=self._dev_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )

    @pl.data_loader
    def test_dataloader(self) -> DataLoader:
        """ Function that loads the validation set. """
        self._test_dataset = self.__retrieve_dataset(train=False, val=False)[0]
        return DataLoader(
            dataset=self._test_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )
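
    # Only the train loader shuffles (via RandomSampler); the validation and
    # test loaders iterate their datasets in order. All three reuse
    # prepare_sample() as the collate_fn, so tokenization and batching stay
    # consistent across splits.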

    @classmethod
    def add_model_specific_args(
            cls, parser: HyperOptArgumentParser) -> HyperOptArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters.
        :param parser: HyperOptArgumentParser obj

        Returns:
            - updated parser
        """
        parser.add_argument(
            "--encoder_model",
            default="bert-base-uncased",
            type=str,
            help="Encoder model to be used.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=1e-05,
            type=float,
            help="Encoder specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.opt_list(
            "--nr_frozen_epochs",
            default=1,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
            tunable=True,
            options=[0, 1, 2, 3, 4, 5],
        )
        # Data Args:
        parser.add_argument(
            "--label_set",
            default="pos,neg",
            type=str,
            help="Classification labels set.",
        )
        parser.add_argument(
            "--train_csv",
            default="data/imdb_reviews_train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_csv",
            default="data/imdb_reviews_test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_csv",
            default="data/imdb_reviews_test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        return parser
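

# Illustrative wiring of the parser defined above (a hedged sketch: `Classifier`
# is a placeholder name for this Lightning module, and the constructor signature
# is assumed, not confirmed by this file):
#
#   parser = HyperOptArgumentParser(strategy="random_search", add_help=True)
#   parser = Classifier.add_model_specific_args(parser)
#   hparams = parser.parse_args()
#   model = Classifier(hparams)

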
def label_encoder():
    """ Builds a LabelEncoder from two sample relation labels. """
    sample = [
        'people/deceased_person/place_of_death',
        'symbols/name_source/namesakes'
    ]
    return LabelEncoder(sample)
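

# Hedged usage sketch for the helper above (assumes torchnlp's LabelEncoder API;
# exact return types may differ between versions):
#
#   encoder = label_encoder()
#   idx = encoder.encode('symbols/name_source/namesakes')   # e.g. tensor(2)
#   encoder.index_to_token[idx.item()]                      # -> 'symbols/name_source/namesakes'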