Example #1
 def build_callbacks(self, save_dir, logger, **kwargs):
     metrics = kwargs.get('metrics', 'accuracy')
     if isinstance(metrics, (list, tuple)):
         metrics = metrics[-1]
     monitor = f'val_{metrics}'
     checkpoint = tf.keras.callbacks.ModelCheckpoint(
         os.path.join(save_dir, 'model.h5'),
         # verbose=1,
         monitor=monitor, save_best_only=True,
         mode='max',
         save_weights_only=True)
     logger.debug(f'Monitor {checkpoint.monitor} for checkpoint')
     tensorboard_callback = tf.keras.callbacks.TensorBoard(
         log_dir=io_util.makedirs(io_util.path_join(save_dir, 'logs')))
     csv_logger = FineCSVLogger(os.path.join(save_dir, 'train.log'), separator=' | ', append=True)
     callbacks = [checkpoint, tensorboard_callback, csv_logger]
     lr_decay_per_epoch = self.config.get('lr_decay_per_epoch', None)
     if lr_decay_per_epoch:
         learning_rate = self.model.optimizer.get_config().get('learning_rate', None)
         if not learning_rate:
             logger.warning('Learning rate decay not supported for optimizer={}'.format(repr(self.model.optimizer)))
         else:
             logger.debug(f'Created LearningRateScheduler with lr_decay_per_epoch={lr_decay_per_epoch}')
             callbacks.append(tf.keras.callbacks.LearningRateScheduler(
                 lambda epoch: learning_rate / (1 + lr_decay_per_epoch * epoch)))
     anneal_factor = self.config.get('anneal_factor', None)
     if anneal_factor:
         callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(factor=anneal_factor,
                                                               patience=self.config.get('anneal_patience', 10)))
     early_stopping_patience = self.config.get('early_stopping_patience', None)
     if early_stopping_patience:
         callbacks.append(tf.keras.callbacks.EarlyStopping(monitor=monitor, mode='max',
                                                           verbose=1,
                                                           patience=early_stopping_patience))
     return callbacks
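
The LearningRateScheduler above applies inverse time decay, lr(epoch) = lr0 / (1 + lr_decay_per_epoch * epoch). A minimal standalone sketch of that schedule, assuming TensorFlow 2.x and made-up values:

import tensorflow as tf

lr0, lr_decay_per_epoch = 0.01, 0.05  # hypothetical values
scheduler = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: lr0 / (1 + lr_decay_per_epoch * epoch))
# epoch 0 -> 0.0100, epoch 10 -> ~0.0067, epoch 20 -> 0.0050
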
Example #2
def load_word2vec(path,
                  delimiter=' ',
                  cache=True) -> Tuple[Dict[str, np.ndarray], int]:
    realpath = get_resource(path)
    binpath = replace_ext(realpath, '.pkl')
    if cache:
        try:
            word2vec, dim = load_pickle(binpath)
            logger.debug(f'Loaded {binpath}')
            return word2vec, dim
        except IOError:
            pass

    dim = None
    word2vec = dict()
    with open(realpath, encoding='utf-8', errors='ignore') as f:
        for idx, line in enumerate(f):
            line = line.rstrip().split(delimiter)
            if len(line) > 2:
                if dim is None:
                    dim = len(line)
                else:
                    if len(line) != dim:
                        logger.warning(
                            '{}#{} length does not match {}'.format(
                                path, idx + 1, dim))
                        continue
                word, vec = line[0], line[1:]
                word2vec[word] = np.array(vec, dtype=np.float32)
    dim -= 1
    if cache:
        save_pickle((word2vec, dim), binpath)
        logger.debug(f'Cached {binpath}')
    return word2vec, dim
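
A brief usage sketch (not from the source): the path below is hypothetical, and get_resource, replace_ext, load_pickle and save_pickle are assumed to come from the surrounding project. The expected text format is one token per line followed by its vector components, separated by delimiter; a two-column header line such as "400000 100" fails the len(line) > 2 check and is simply skipped.

word2vec, dim = load_word2vec('data/glove/glove.6B.100d.txt')  # hypothetical path
print(dim)                    # embedding dimensionality, e.g. 100
print(word2vec['the'].shape)  # (100,)
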
Example #3
def load_word2vec(path, delimiter=' ', cache=True) -> Tuple[Dict[str, np.ndarray], int]:
    realpath = get_resource(path)
    binpath = replace_ext(realpath, '.pkl')
    if cache:
        try:
            flash('Loading word2vec from cache [blink][yellow]...[/yellow][/blink]')
            word2vec, dim = load_pickle(binpath)
            flash('')
            return word2vec, dim
        except IOError:
            pass

    dim = None
    word2vec = dict()
    f = TimingFileIterator(realpath)
    for idx, line in enumerate(f):
        f.log('Loading word2vec from text file [blink][yellow]...[/yellow][/blink]')
        line = line.rstrip().split(delimiter)
        if len(line) > 2:
            if dim is None:
                dim = len(line)
            else:
                if len(line) != dim:
                    logger.warning('{}#{} length does not match {}'.format(path, idx + 1, dim))
                    continue
            word, vec = line[0], line[1:]
            word2vec[word] = np.array(vec, dtype=np.float32)
    dim -= 1
    if cache:
        flash('Caching word2vec [blink][yellow]...[/yellow][/blink]')
        save_pickle((word2vec, dim), binpath)
        flash('')
    return word2vec, dim
Example #4
 def batched_inputs_to_batches(self, corpus, indices, shuffle):
     use_pos = self.use_pos
     raw_batch = [[], [], [], []] if use_pos else [[], [], []]
     max_len = len(max([corpus[i] for i in indices], key=len))
     for idx in indices:
         arc = np.zeros((max_len, max_len), dtype=bool)
         rel = np.zeros((max_len, max_len), dtype=np.int64)
         for b in raw_batch[:2]:
             b.append([])
         for m, cells in enumerate(corpus[idx]):
             if use_pos:
                 for b, c, v in zip(raw_batch, cells,
                                    [self.form_vocab, self.cpos_vocab]):
                     b[-1].append(v.get_idx_without_add(c))
             else:
                 for b, c, v in zip(raw_batch, cells, [self.form_vocab]):
                     b[-1].append(v.get_idx_without_add(c))
             for n, r in zip(cells[-2], cells[-1]):
                 arc[m, n] = True
                 rid = self.rel_vocab.get_idx_without_add(r)
                 if rid is None:
                     logger.warning(
                         f'Relation OOV: {r} does not exist in train')
                     continue
                 rel[m, n] = rid
         raw_batch[-2].append(arc)
         raw_batch[-1].append(rel)
     batch = []
     for b, v in zip(raw_batch, [self.form_vocab, self.cpos_vocab]):
         b = tf.keras.preprocessing.sequence.pad_sequences(
             b, padding='post', value=v.safe_pad_token_idx, dtype='int64')
         batch.append(b)
     batch += raw_batch[2:]
     assert len(batch) == 4
     yield (batch[0], batch[1]), (batch[2], batch[3])
Example #5
 def generator():
     # custom bucketing, load corpus into memory
     corpus = list(samples() if callable(samples) else samples)
     lengths = [1 + len(i) for i in corpus]
     if len(corpus) < 32:
         n_buckets = 1
     else:
         n_buckets = min(self.config.n_buckets, len(corpus))
     buckets = dict(zip(*kmeans(lengths, n_buckets)))
     sizes, buckets = zip(*[
         (size, bucket) for size, bucket in buckets.items()
     ])
     # the number of chunks in each bucket, which is clipped by
     # range [1, len(bucket)]
     chunks = [min(len(bucket), max(round(size * len(bucket) / batch_size), 1)) for size, bucket in
               zip(sizes, buckets)]
     range_fn = randperm if shuffle else arange
     for i in tolist(range_fn(len(buckets))):
         split_sizes = [(len(buckets[i]) - j - 1) // chunks[i] + 1
                        for j in range(chunks[i])]
         for batch_indices in tf.split(range_fn(len(buckets[i])), split_sizes):
             indices = [buckets[i][j] for j in tolist(batch_indices)]
             raw_batch = [[], [], [], []]
             max_len = len(max([corpus[i] for i in indices], key=len))
             for idx in indices:
                 arc = np.zeros((max_len, max_len), dtype=bool)
                 rel = np.zeros((max_len, max_len), dtype=np.int64)
                 for b in raw_batch[:2]:
                     b.append([])
                 for m, cells in enumerate(corpus[idx]):
                     for b, c, v in zip(raw_batch, cells,
                                        [self.form_vocab, self.cpos_vocab]):
                         b[-1].append(v.get_idx_without_add(c))
                     for n, r in zip(cells[2], cells[3]):
                         arc[m, n] = True
                         rid = self.rel_vocab.get_idx_without_add(r)
                         if rid is None:
                             logger.warning(f'Relation OOV: {r} does not exist in train')
                             continue
                         rel[m, n] = rid
                 raw_batch[-2].append(arc)
                 raw_batch[-1].append(rel)
             batch = []
             for b, v in zip(raw_batch, [self.form_vocab, self.cpos_vocab]):
                 b = tf.keras.preprocessing.sequence.pad_sequences(b, padding='post',
                                                                   value=v.safe_pad_token_idx,
                                                                   dtype='int64')
                 batch.append(b)
             batch += raw_batch[2:]
             assert len(batch) == 4
             yield (batch[0], batch[1]), (batch[2], batch[3])
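
The split_sizes expression above partitions the i-th bucket into chunks[i] nearly equal batches whose sizes sum back to the bucket size. A standalone illustration of just that arithmetic (pure Python, no TensorFlow):

def split_sizes(n, chunks):
    # same formula as above: the first n % chunks batches receive one extra sentence
    return [(n - j - 1) // chunks + 1 for j in range(chunks)]

assert split_sizes(10, 3) == [4, 3, 3]
assert sum(split_sizes(10, 3)) == 10
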
Example #6
    def batched_inputs_to_batches(self, corpus, indices, shuffle=False):
        """
        Convert batched inputs to batches of samples

        Parameters
        ----------
        corpus : list
            A list of inputs
        indices : list
            Indices of the samples that make up this batch
        shuffle : bool
            Unused in this implementation; kept for interface compatibility

        Yields
        ------
        tuple
            Tuple of tf.Tensor: ((forms, cpos), (arcs, rels))
        """
        raw_batch = [[], [], [], []]
        max_len = len(max([corpus[i] for i in indices], key=len))
        for idx in indices:
            arc = np.zeros((max_len, max_len), dtype=bool)
            rel = np.zeros((max_len, max_len), dtype=np.int64)
            for b in raw_batch[:2]:
                b.append([])
            for m, cells in enumerate(corpus[idx]):
                for b, c, v in zip(raw_batch, cells,
                                   [self.form_vocab, self.cpos_vocab]):
                    b[-1].append(v.get_idx_without_add(c))
                for n, r in zip(cells[2], cells[3]):
                    arc[m, n] = True
                    rid = self.rel_vocab.get_idx_without_add(r)
                    if rid is None:
                        logger.warning(f'Relation OOV: {r} does not exist in train')
                        continue
                    rel[m, n] = rid
            raw_batch[-2].append(arc)
            raw_batch[-1].append(rel)
        batch = []
        for b, v in zip(raw_batch, [self.form_vocab, self.cpos_vocab]):
            b = tf.keras.preprocessing.sequence.pad_sequences(b, padding='post',
                                                              value=v.safe_pad_token_idx,
                                                              dtype='int64')
            batch.append(b)
        batch += raw_batch[2:]
        assert len(batch) == 4
        yield (batch[0], batch[1]), (batch[2], batch[3])
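
For reference, the arc/rel matrices built above encode one dependency graph per sentence: arc[m, n] is True when token m has head n, and rel[m, n] stores the id of that relation. A small hand-made illustration (sentence and relation ids are hypothetical):

import numpy as np

max_len = 4
arc = np.zeros((max_len, max_len), dtype=bool)
rel = np.zeros((max_len, max_len), dtype=np.int64)
# tokens 1 and 2 both attach to token 0 with different relations
arc[1, 0], rel[1, 0] = True, 3  # hypothetical id for 'nsubj'
arc[2, 0], rel[2, 0] = True, 5  # hypothetical id for 'obj'
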
Example #7
    def inputs_to_samples(self, inputs, gold=False):
        tokenizer = self.tokenizer
        max_length = self.config.max_length
        num_features = None
        pad_token = None if self.label_vocab.mutable else tokenizer.convert_tokens_to_ids(['[PAD]'])[0]
        for (X, Y) in super().inputs_to_samples(inputs, gold):
            if self.label_vocab.mutable:
                yield None, Y
                continue
            if isinstance(X, str):
                X = (X,)
            if num_features is None:
                num_features = self.config.num_features
            assert num_features == len(X), f'Numbers of features {num_features} ' \
                                           f'inconsistent with current {len(X)}={X}'
            text_a = X[0]
            text_b = X[1] if len(X) > 1 else None
            tokens_a = self.tokenizer.tokenize(text_a)
            tokens_b = self.tokenizer.tokenize(text_b) if text_b else None
            tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
            segment_ids = [0] * len(tokens)
            if tokens_b:
                tokens += tokens_b
                segment_ids += [1] * len(tokens_b)
            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            attention_mask = [1] * len(token_ids)
            diff = max_length - len(token_ids)
            if diff < 0:
                logger.warning(
                    f'Input tokens {tokens} exceed the max sequence length of {max_length - 2}. '
                    f'The part beyond that length will be truncated and ignored. '
                    f'Consider splitting your long text into several sentences within '
                    f'{max_length - 2} tokens beforehand.')
                token_ids = token_ids[:max_length]
                attention_mask = attention_mask[:max_length]
                segment_ids = segment_ids[:max_length]
            elif diff > 0:
                token_ids += [pad_token] * diff
                attention_mask += [0] * diff
                segment_ids += [0] * diff

            assert len(token_ids) == max_length, "Error with input length {} vs {}".format(len(token_ids), max_length)
            assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask),
                                                                                                max_length)
            assert len(segment_ids) == max_length, "Error with input length {} vs {}".format(len(segment_ids),
                                                                                             max_length)

            label = Y
            yield (token_ids, attention_mask, segment_ids), label
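
A hand-worked alignment (hypothetical sentence pair) of the three sequences yielded above, before padding to max_length; note that the method appends tokens_b without a trailing [SEP]:

tokens         = ['[CLS]', 'how', 'are', 'you', '[SEP]', 'fine']
segment_ids    = [0,        0,     0,     0,     0,       1]
attention_mask = [1] * len(tokens)  # 1 for every real token; 0s are appended for padding
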
Example #8
 def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
     logger.warning(
         'TableTransform cannot map x to idx. Please override x_to_idx')
     return x
Example #9
def convert_examples_to_features(words,
                                 max_seq_length,
                                 tokenizer,
                                 labels=None,
                                 label_map=None,
                                 cls_token_at_end=False,
                                 cls_token="[CLS]",
                                 cls_token_segment_id=1,
                                 sep_token="[SEP]",
                                 sep_token_extra=False,
                                 pad_on_left=False,
                                 pad_token_id=0,
                                 pad_token_segment_id=0,
                                 pad_token_label_id=0,
                                 sequence_a_segment_id=0,
                                 mask_padding_with_zero=True,
                                 unk_token='[UNK]',
                                 do_padding=True):
    """Loads a data file into a list of `InputBatch`s
        `cls_token_at_end` define the location of the CLS token:
            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
        `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)

    Args:
      words: 
      max_seq_length: 
      tokenizer: 
      labels:  (Default value = None)
      label_map:  (Default value = None)
      cls_token_at_end:  (Default value = False)
      cls_token:  (Default value = "[CLS]")
      cls_token_segment_id:  (Default value = 1)
      sep_token:  (Default value = "[SEP]")
      sep_token_extra:  (Default value = False)
      pad_on_left:  (Default value = False)
      pad_token_id:  (Default value = 0)
      pad_token_segment_id:  (Default value = 0)
      pad_token_label_id:  (Default value = 0)
      sequence_a_segment_id:  (Default value = 0)
      mask_padding_with_zero:  (Default value = True)
      unk_token:  (Default value = '[UNK]')
      do_padding:  (Default value = True)

    Returns:
      A tuple of (input_ids, input_mask, segment_ids, label_ids).

    """
    args = locals()
    if not labels:
        labels = words
        pad_token_label_id = False

    tokens = []
    label_ids = []
    for word, label in zip(words, labels):
        word_tokens = tokenizer.tokenize(word)
        if not word_tokens:
            # some weird chars cause the tokenizer to return an empty list
            word_tokens = [unk_token] * len(word)
        tokens.extend(word_tokens)
        # Use the real label id for the first token of the word, and padding ids for the remaining tokens
        label_ids.extend([label_map[label] if label_map else True] +
                         [pad_token_label_id] * (len(word_tokens) - 1))

    # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
    special_tokens_count = 3 if sep_token_extra else 2
    if len(tokens) > max_seq_length - special_tokens_count:
        logger.warning(
            f'Input tokens {words} exceed the max sequence length of {max_seq_length - special_tokens_count}. '
            f'The part beyond that length will be truncated and ignored. '
            f'Consider splitting your long text into several sentences within '
            f'{max_seq_length - special_tokens_count} tokens beforehand.')
        tokens = tokens[:(max_seq_length - special_tokens_count)]
        label_ids = label_ids[:(max_seq_length - special_tokens_count)]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  token_type_ids:   0   0   0   0  0     0   0
    #
    # Where "token_type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens += [sep_token]
    label_ids += [pad_token_label_id]
    if sep_token_extra:
        # roberta uses an extra separator b/w pairs of sentences
        tokens += [sep_token]
        label_ids += [pad_token_label_id]
    segment_ids = [sequence_a_segment_id] * len(tokens)

    if cls_token_at_end:
        tokens += [cls_token]
        label_ids += [pad_token_label_id]
        segment_ids += [cls_token_segment_id]
    else:
        tokens = [cls_token] + tokens
        label_ids = [pad_token_label_id] + label_ids
        segment_ids = [cls_token_segment_id] + segment_ids

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

    if do_padding:
        # Zero-pad up to the sequence length.
        padding_length = max_seq_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token_id] * padding_length) + input_ids
            input_mask = ([0 if mask_padding_with_zero else 1] *
                          padding_length) + input_mask
            segment_ids = ([pad_token_segment_id] *
                           padding_length) + segment_ids
            label_ids = ([pad_token_label_id] * padding_length) + label_ids
        else:
            input_ids += [pad_token_id] * padding_length
            input_mask += [0 if mask_padding_with_zero else 1] * padding_length
            segment_ids += [pad_token_segment_id] * padding_length
            label_ids += [pad_token_label_id] * padding_length

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length, f'failed for:\n {args}'
    else:
        assert len(
            set(
                len(x)
                for x in [input_ids, input_mask, segment_ids, label_ids])) == 1
    return input_ids, input_mask, segment_ids, label_ids
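
A hedged usage sketch (not from the source): it assumes the function above is in scope and that the HuggingFace transformers package is installed; the sentence and label scheme are made up.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
words = ['Jack', 'lives', 'in', 'Jacksonville', '.']
labels = ['B-PER', 'O', 'O', 'B-LOC', 'O']
label_map = {'O': 1, 'B-PER': 2, 'B-LOC': 3}
input_ids, input_mask, segment_ids, label_ids = convert_examples_to_features(
    words, max_seq_length=16, tokenizer=tokenizer,
    labels=labels, label_map=label_map)
assert len(input_ids) == 16  # padded up to max_seq_length because do_padding=True
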
Example #10
    def batched_inputs_to_batches(self, corpus, indices, shuffle):
        use_pos = self.use_pos
        if use_pos:
            raw_batch = [[], [], [], []]
        else:
            raw_batch = [[], [], []]
        if self.graph:
            max_len = len(max([corpus[i] for i in indices], key=len))
            for idx in indices:
                arc = np.zeros((max_len, max_len), dtype=bool)
                rel = np.zeros((max_len, max_len), dtype=np.int64)
                for b in raw_batch[:2 if use_pos else 1]:
                    b.append([])
                for m, cells in enumerate(corpus[idx]):
                    if use_pos:
                        for b, c, v in zip(raw_batch, cells, [None, self.cpos_vocab]):
                            b[-1].append(v.get_idx_without_add(c) if v else c)
                    else:
                        for b, c, v in zip(raw_batch, cells, [None]):
                            b[-1].append(c)
                    for n, r in zip(cells[-2], cells[-1]):
                        arc[m, n] = True
                        rid = self.rel_vocab.get_idx_without_add(r)
                        if rid is None:
                            logger.warning(f'Relation OOV: {r} does not exist in train')
                            continue
                        rel[m, n] = rid
                raw_batch[-2].append(arc)
                raw_batch[-1].append(rel)
        else:
            for idx in indices:
                for s in raw_batch:
                    s.append([])
                for cells in corpus[idx]:
                    if use_pos:
                        for s, c, v in zip(raw_batch, cells, [None, self.cpos_vocab, None, self.rel_vocab]):
                            s[-1].append(v.get_idx_without_add(c) if v else c)
                    else:
                        for s, c, v in zip(raw_batch, cells, [None, None, self.rel_vocab]):
                            s[-1].append(v.get_idx_without_add(c) if v else c)

        # Transformer tokenizing
        config = self.transformer_config
        tokenizer = self.tokenizer
        xlnet = config_is(config, 'xlnet')
        roberta = config_is(config, 'roberta')
        pad_token = tokenizer.pad_token
        pad_token_id = tokenizer.convert_tokens_to_ids([pad_token])[0]
        cls_token = tokenizer.cls_token
        sep_token = tokenizer.sep_token
        max_seq_length = self.config.max_seq_length
        batch_forms = []
        batch_input_ids = []
        batch_input_mask = []
        batch_prefix_offset = []
        mask_p = self.mask_p
        if mask_p:
            batch_masked_offsets = []
            mask_token_id = tokenizer.mask_token_id
        for sent_idx, sent in enumerate(raw_batch[0]):
            batch_forms.append([self.form_vocab.get_idx_without_add(token) for token in sent])
            sent = adjust_tokens_for_transformers(sent)
            sent = sent[1:]  # drop <root>; [CLS] takes its place
            pad_label_idx = self.form_vocab.pad_idx
            input_ids, input_mask, segment_ids, prefix_mask = \
                convert_examples_to_features(sent,
                                             max_seq_length,
                                             tokenizer,
                                             cls_token_at_end=xlnet,
                                             # xlnet has a cls token at the end
                                             cls_token=cls_token,
                                             cls_token_segment_id=2 if xlnet else 0,
                                             sep_token=sep_token,
                                             sep_token_extra=roberta,
                                             # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                                             pad_on_left=xlnet,
                                             # pad on the left for xlnet
                                             pad_token_id=pad_token_id,
                                             pad_token_segment_id=4 if xlnet else 0,
                                             pad_token_label_id=pad_label_idx,
                                             do_padding=False)
            num_masks = sum(prefix_mask)
            # assert len(sent) == num_masks  # each token has a True subtoken
            if num_masks < len(sent):  # long sent gets truncated, +1 for root
                batch_forms[-1] = batch_forms[-1][:num_masks + 1]  # form
                raw_batch[-1][sent_idx] = raw_batch[-1][sent_idx][:num_masks + 1]  # rel
                raw_batch[-2][sent_idx] = raw_batch[-2][sent_idx][:num_masks + 1]  # head
                raw_batch[-3][sent_idx] = raw_batch[-3][sent_idx][:num_masks + 1]  # pos
            prefix_mask[0] = True  # <root> is now [CLS]
            prefix_offset = [idx for idx, m in enumerate(prefix_mask) if m]
            batch_input_ids.append(input_ids)
            batch_input_mask.append(input_mask)
            batch_prefix_offset.append(prefix_offset)
            if mask_p:
                if shuffle:
                    size = int(np.ceil(mask_p * len(prefix_offset[1:])))  # never mask [CLS]
                    mask_offsets = np.random.choice(np.arange(1, len(prefix_offset)), size, replace=False)
                    for offset in sorted(mask_offsets):
                        assert 0 < offset < len(input_ids)
                        # mask_word = raw_batch[0][sent_idx][offset]
                        # mask_prefix = tokenizer.convert_ids_to_tokens([input_ids[prefix_offset[offset]]])[0]
                        # assert mask_word.startswith(mask_prefix) or mask_prefix.startswith(
                        #     mask_word) or mask_prefix == "'", \
                        #     f'word {mask_word} prefix {mask_prefix} not match'  # could vs couldn
                        # mask_offsets.append(input_ids[offset]) # subword token
                        # mask_offsets.append(offset)  # form token
                        input_ids[prefix_offset[offset]] = mask_token_id  # mask prefix
                        # whole word masking, mask the rest of the word
                        for i in range(prefix_offset[offset] + 1, len(input_ids) - 1):
                            if prefix_mask[i]:
                                break
                            input_ids[i] = mask_token_id

                    batch_masked_offsets.append(sorted(mask_offsets))
                else:
                    batch_masked_offsets.append([0])  # No masking in prediction

        batch_forms = tf.keras.preprocessing.sequence.pad_sequences(batch_forms, padding='post',
                                                                    value=self.form_vocab.safe_pad_token_idx,
                                                                    dtype='int64')
        batch_input_ids = tf.keras.preprocessing.sequence.pad_sequences(batch_input_ids, padding='post',
                                                                        value=pad_token_id,
                                                                        dtype='int64')
        batch_input_mask = tf.keras.preprocessing.sequence.pad_sequences(batch_input_mask, padding='post',
                                                                         value=0,
                                                                         dtype='int64')
        batch_prefix_offset = tf.keras.preprocessing.sequence.pad_sequences(batch_prefix_offset, padding='post',
                                                                            value=0,
                                                                            dtype='int64')
        batch_heads = tf.keras.preprocessing.sequence.pad_sequences(raw_batch[-2], padding='post',
                                                                    value=0,
                                                                    dtype='int64')
        batch_rels = tf.keras.preprocessing.sequence.pad_sequences(raw_batch[-1], padding='post',
                                                                   value=self.rel_vocab.safe_pad_token_idx,
                                                                   dtype='int64')
        if mask_p:
            batch_masked_offsets = tf.keras.preprocessing.sequence.pad_sequences(batch_masked_offsets, padding='post',
                                                                                 value=pad_token_id,
                                                                                 dtype='int64')
        feats = (tf.constant(batch_input_ids, dtype='int64'), tf.constant(batch_input_mask, dtype='int64'),
                 tf.constant(batch_prefix_offset))
        if use_pos:
            batch_pos = tf.keras.preprocessing.sequence.pad_sequences(raw_batch[1], padding='post',
                                                                      value=self.cpos_vocab.safe_pad_token_idx,
                                                                      dtype='int64')
            feats += (batch_pos,)
        yield (batch_forms, feats), \
              (batch_heads, batch_rels, batch_masked_offsets) if mask_p else (batch_heads, batch_rels)
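
The whole-word-masking loop above replaces a word's first sub-token and every following non-prefix sub-token with the mask id, stopping at the next word start. A standalone sketch with made-up ids (mirroring, not reusing, the code above):

mask_token_id = 103                            # hypothetical [MASK] id
input_ids   = [101, 7592, 2088, 2386, 102]     # [CLS] w1 w2a w2b [SEP], all ids made up
prefix_mask = [True, True, True, False, True]  # True where a new word starts
start = 2                                      # mask the word beginning at position 2
input_ids[start] = mask_token_id
for i in range(start + 1, len(input_ids) - 1):  # never touch the final [SEP]
    if prefix_mask[i]:
        break
    input_ids[i] = mask_token_id
# input_ids is now [101, 7592, 103, 103, 102]
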