Example No. 1
def evaluate(gold_file,
             pred_file,
             do_enhanced_collapse_empty_nodes=False,
             do_copy_cols=True):
    """Evaluate using official CoNLL-X evaluation script (Yuval Krymolowski)

    Args:
      gold_file(str): The gold conllx file
      pred_file(str): The pred conllx file
      do_enhanced_collapse_empty_nodes:  (Default value = False)
      do_copy_cols:  (Default value = True)

    Returns:

    
    """
    if do_enhanced_collapse_empty_nodes:
        gold_file = enhanced_collapse_empty_nodes(gold_file)
        pred_file = enhanced_collapse_empty_nodes(pred_file)
    if do_copy_cols:
        fixed_pred_file = pred_file.replace('.conllu', '.fixed.conllu')
        copy_cols(gold_file, pred_file, fixed_pred_file)
    else:
        fixed_pred_file = pred_file
    args = SerializableDict()
    args.enhancements = '0'
    args.gold_file = gold_file
    args.system_file = fixed_pred_file
    return iwpt20_xud_eval.evaluate_wrapper(args)
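The args object above is populated through attribute access but read like a dict by the evaluation wrapper. Below is a minimal stand-in sketch of that attribute-as-key pattern; AttrDict and the dummy file names are hypothetical, not the library's actual SerializableDict implementation.

# Minimal stand-in for the attribute-as-key pattern used above.
# AttrDict is a hypothetical class, not SerializableDict's real implementation.
class AttrDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value


args = AttrDict()
args.enhancements = '0'                 # stored as args['enhancements']
args.gold_file = 'gold.conllu'          # dummy paths, for illustration only
args.system_file = 'pred.fixed.conllu'
assert dict(args) == {'enhancements': '0',
                      'gold_file': 'gold.conllu',
                      'system_file': 'pred.fixed.conllu'}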
Example No. 2
    def _savable_config(self):
        def convert(k, v):
            if hasattr(v, 'config'):
                v = v.config
            if isinstance(v, (set, tuple)):
                v = list(v)
            return k, v

        config = SerializableDict(
            convert(k, v) for k, v in sorted(self.config.items()))
        config.update({
            # 'create_time': now_datetime(),
            'classpath': classpath_of(self),
            'elit_version': elit.__version__,
        })
        return config
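The convert helper exists so that the snapshot survives JSON serialization: nested components are replaced by their own config, and sets or tuples become lists. The sketch below mirrors that normalization with the standard json module; the keys and values are made up.

import json

# Illustrative normalization mirroring convert() above: sets and tuples are
# turned into lists so that json.dumps does not raise a TypeError.
raw = {'tag_set': {'NN', 'VB'}, 'hidden_sizes': (128, 64), 'lr': 1e-3}
config = {k: (list(v) if isinstance(v, (set, tuple)) else v)
          for k, v in sorted(raw.items())}
print(json.dumps(config, indent=2))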
Example No. 3
    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.model: Optional[torch.nn.Module] = None
        self.config = SerializableDict(**kwargs)
        self.vocabs = VocabDict()
Example No. 4
    def load_vocab(self, save_dir, filename='vocab.json'):
        save_dir = get_resource(save_dir)
        vocab = SerializableDict()
        vocab.load_json(os.path.join(save_dir, filename))
        self.vocab.copy_from(vocab)
Example No. 5
    def save_vocab(self, save_dir, filename='vocab.json'):
        vocab = SerializableDict()
        vocab.update(self.vocab.to_dict())
        vocab.save_json(os.path.join(save_dir, filename))
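Read together, Examples No. 4 and No. 5 round-trip a vocabulary through a plain JSON file on disk. The sketch below reproduces that round trip with only the standard library; the directory, file name, and vocabulary contents are made up.

import json
import os

# Stand-in for the save_json/load_json round trip above, using the json
# module directly instead of SerializableDict. All values are made up.
save_dir = '/tmp/vocab_demo'
os.makedirs(save_dir, exist_ok=True)
vocab = {'<pad>': 0, '<unk>': 1, 'the': 2}

with open(os.path.join(save_dir, 'vocab.json'), 'w', encoding='utf-8') as f:
    json.dump(vocab, f, ensure_ascii=False)

with open(os.path.join(save_dir, 'vocab.json'), encoding='utf-8') as f:
    restored = json.load(f)

assert restored['the'] == 2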
Example No. 6
    def load_vocabs(self, save_dir, filename='vocabs.json', vocab_cls=Vocab):
        save_dir = get_resource(save_dir)
        vocabs = SerializableDict()
        vocabs.load_json(os.path.join(save_dir, filename))
        self._load_vocabs(self, vocabs, vocab_cls)
Example No. 7
    def save_vocabs(self, save_dir, filename='vocabs.json'):
        vocabs = SerializableDict()
        for key, value in self.items():
            if isinstance(value, Vocab):
                vocabs[key] = value.to_dict()
        vocabs.save_json(os.path.join(save_dir, filename))
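Example No. 7 nests each Vocab under its field name before writing a single vocabs.json, and Example No. 6 reverses the process. The sketch below shows one plausible file layout; the idx_to_token key is only a guess at what Vocab.to_dict() might return.

import json

# Plausible vocabs.json layout: one entry per field, each holding that
# field's serialized vocabulary. The inner keys are hypothetical.
vocabs = {
    'token': {'idx_to_token': ['<pad>', '<unk>', 'the', 'cat']},
    'pos':   {'idx_to_token': ['<pad>', '<unk>', 'DT', 'NN']},
}
print(json.dumps(vocabs, indent=2))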
Example No. 8
    def __call__(self, sample: dict):
        input_tokens = sample[self.input_key]
        input_is_str = isinstance(input_tokens, str)
        tokenizer = self.tokenizer
        ret_token_span = self.ret_token_span
        if input_is_str:  # This happens in a tokenizer component where the raw sentence is fed.

            def tokenize_str(input_tokens, add_special_tokens=True):
                if tokenizer.is_fast:
                    encoding = tokenizer.encode_plus(
                        input_tokens,
                        return_offsets_mapping=True,
                        add_special_tokens=add_special_tokens).encodings[0]
                    subtoken_offsets = encoding.offsets
                    if add_special_tokens:
                        subtoken_offsets = subtoken_offsets[1:-1]
                    input_tokens = encoding.tokens
                    input_ids = encoding.ids
                else:
                    input_tokens = tokenizer.tokenize(input_tokens)
                    subtoken_offsets = input_tokens
                    if add_special_tokens:
                        input_tokens = [self.cls_token
                                        ] + input_tokens + [self.sep_token]
                    input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
                return input_tokens, input_ids, subtoken_offsets

            if self.dict:
                chunks = self.dict.split(input_tokens)
                _input_tokens, _input_ids, _subtoken_offsets = [
                    self.cls_token
                ], [tokenizer.cls_token_id], []
                _offset = 0
                custom_words = sample['custom_words'] = []
                for chunk in chunks:
                    if isinstance(chunk, str):
                        tokens, ids, offsets = tokenize_str(
                            chunk, add_special_tokens=False)
                    else:
                        begin, end, label = chunk
                        custom_words.append(chunk)
                        if isinstance(label, list):
                            tokens, ids, offsets, delta = [], [], [], 0
                            for token in label:
                                _tokens, _ids, _offsets = tokenize_str(
                                    token, add_special_tokens=False)
                                tokens.extend(_tokens)
                                ids.extend(_ids)
                                offsets.append((_offsets[0][0] + delta,
                                                _offsets[-1][-1] + delta))
                                delta = offsets[-1][-1]
                        else:
                            tokens, ids, offsets = tokenize_str(
                                input_tokens[begin:end],
                                add_special_tokens=False)
                            offsets = [(offsets[0][0], offsets[-1][-1])]
                    _input_tokens.extend(tokens)
                    _input_ids.extend(ids)
                    _subtoken_offsets.extend(
                        (x[0] + _offset, x[1] + _offset) for x in offsets)
                    _offset = _subtoken_offsets[-1][-1]
                subtoken_offsets = _subtoken_offsets
                input_tokens = _input_tokens + [self.sep_token]
                input_ids = _input_ids + [tokenizer.sep_token_id]
            else:
                input_tokens, input_ids, subtoken_offsets = tokenize_str(
                    input_tokens, add_special_tokens=True)

            if self.ret_subtokens:
                sample[f'{self.input_key}_subtoken_offsets'] = subtoken_offsets

        cls_is_bos = self.cls_is_bos
        if cls_is_bos is None:
            cls_is_bos = input_tokens[0] == BOS
        sep_is_eos = self.sep_is_eos
        if sep_is_eos is None:
            sep_is_eos = input_tokens[-1] == EOS
        if self.strip_cls_sep:
            if cls_is_bos:
                input_tokens = input_tokens[1:]
            if sep_is_eos:
                input_tokens = input_tokens[:-1]
        if not self.ret_mask_and_type:  # only need input_ids and token_span, use a light version
            if input_is_str:
                prefix_mask = self._init_prefix_mask(input_ids)
            else:
                if input_tokens:
                    encodings = tokenizer.batch_encode_plus(
                        input_tokens,
                        return_offsets_mapping=tokenizer.is_fast
                        and self.ret_subtokens,
                        add_special_tokens=False)
                else:
                    encodings = SerializableDict()
                    encodings.data = {'input_ids': []}
                subtoken_ids_per_token = encodings.data['input_ids']
                # Some tokens get stripped out entirely by the tokenizer; fall back to [UNK] for them
                subtoken_ids_per_token = [
                    ids if ids else [tokenizer.unk_token_id]
                    for ids in subtoken_ids_per_token
                ]
                input_ids = sum(subtoken_ids_per_token,
                                [tokenizer.cls_token_id])
                if self.sep_is_eos is None:
                    # None means to check whether sep is at the tail or between tokens
                    if sep_is_eos:
                        input_ids += [tokenizer.sep_token_id]
                    elif tokenizer.sep_token_id not in input_ids:
                        input_ids += [tokenizer.sep_token_id]
                else:
                    input_ids += [tokenizer.sep_token_id]
                # else: self.sep_is_eos == False means sep appears between tokens, so there is no need to check

                if self.ret_subtokens:
                    prefix_mask = self._init_prefix_mask(input_ids)
                    ret_token_span = bool(prefix_mask)
                else:
                    prefix_mask = [False] * len(input_ids)
                    offset = 1
                    for _subtokens in subtoken_ids_per_token:
                        prefix_mask[offset] = True
                        offset += len(_subtokens)
                if self.ret_subtokens:
                    subtoken_offsets = []
                    for token, encoding in zip(input_tokens,
                                               encodings.encodings):
                        if encoding.offsets:
                            subtoken_offsets.append(encoding.offsets)
                        else:
                            subtoken_offsets.append([(0, len(token))])
                    if self.ret_subtokens_group:
                        sample[
                            f'{self.input_key}_subtoken_offsets_group'] = subtoken_offsets
                    sample[f'{self.input_key}_subtoken_offsets'] = sum(
                        subtoken_offsets, [])
        else:
            input_ids, attention_mask, token_type_ids, prefix_mask = \
                convert_examples_to_features(input_tokens,
                                             None,
                                             tokenizer,
                                             cls_token_at_end=self.cls_token_at_end,
                                             # xlnet has a cls token at the end
                                             cls_token=tokenizer.cls_token,
                                             cls_token_segment_id=self.cls_token_segment_id,
                                             sep_token=self.sep_token,
                                             sep_token_extra=self.sep_token_extra,
                                             # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                                             pad_on_left=self.pad_on_left,
                                             # pad on the left for xlnet
                                             pad_token_id=self.pad_token_id,
                                             pad_token_segment_id=self.pad_token_segment_id,
                                             pad_token_label_id=0,
                                             do_padding=self.do_padding)
        if len(input_ids) > self.max_seq_length:
            if self.truncate_long_sequences:
                # raise SequenceTooLong(
                #     f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
                #     f'For sequence tasks, truncate_long_sequences = True is not supported.'
                #     f'You are recommended to split your long text into several sentences within '
                #     f'{self.max_seq_length - 2} tokens beforehand. '
                #     f'Or simply set truncate_long_sequences = False to enable sliding window.')
                input_ids = input_ids[:self.max_seq_length]
                prefix_mask = prefix_mask[:self.max_seq_length]
                warnings.warn(
                    f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
                    f'The exceeded part will be truncated and ignored. '
                    f'You are recommended to split your long text into several sentences within '
                    f'{self.max_seq_length - 2} tokens beforehand. '
                    f'Or simply set truncate_long_sequences = False to enable sliding window.'
                )
            else:
                input_ids = self.sliding_window(
                    input_ids, input_ids[-1] == tokenizer.sep_token_id)
        if prefix_mask:
            if cls_is_bos:
                prefix_mask[0] = True
            if sep_is_eos:
                prefix_mask[-1] = True
        outputs = [input_ids]
        if self.ret_mask_and_type:
            # noinspection PyUnboundLocalVariable
            outputs += [attention_mask, token_type_ids]
        if self.ret_prefix_mask:
            outputs += [prefix_mask]
        if ret_token_span and prefix_mask:
            if cls_is_bos:
                token_span = [[0]]
            else:
                token_span = []
            offset = 1
            span = []
            for mask in prefix_mask[1:len(prefix_mask) if sep_is_eos is None
                                    else -1]:  # skip [CLS] and [SEP]
                if mask and span:
                    token_span.append(span)
                    span = []
                span.append(offset)
                offset += 1
            if span:
                token_span.append(span)
            if sep_is_eos:
                assert offset == len(prefix_mask) - 1
                token_span.append([offset])
            outputs.append(token_span)
        for k, v in zip(self.output_key, outputs):
            sample[k] = v
        return sample
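The heart of this transform is the bookkeeping around prefix_mask (which subword starts a word) and token_span (which subword positions belong to each word). The standalone sketch below replays that logic on made-up subword ids, independent of the class and tokenizer above.

# Standalone replay of the prefix_mask / token_span bookkeeping above,
# using made-up subword ids instead of a real tokenizer.
CLS_ID, SEP_ID = 101, 102                                # assumed special ids
subtoken_ids_per_token = [[7592], [2088, 2015], [999]]   # 3 words, 4 subwords
input_ids = [CLS_ID] + [i for ids in subtoken_ids_per_token for i in ids] + [SEP_ID]

# prefix_mask marks the first subword of every word (cf. the `offset` loop).
prefix_mask = [False] * len(input_ids)
offset = 1
for ids in subtoken_ids_per_token:
    prefix_mask[offset] = True
    offset += len(ids)

# token_span groups subword positions back into words (cf. the
# ret_token_span branch), skipping [CLS] and [SEP].
token_span, span, offset = [], [], 1
for mask in prefix_mask[1:-1]:
    if mask and span:
        token_span.append(span)
        span = []
    span.append(offset)
    offset += 1
if span:
    token_span.append(span)

print(prefix_mask)  # [False, True, True, False, True, False]
print(token_span)   # [[1], [2, 3], [4]]

With these two structures, downstream code can pool subword representations back to one vector per word, which is presumably why the transform records them in the sample.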