def _infer_pad_shape(nested_lists):
    """Yield the minimal tensor shape that could contain ``nested_lists``.

    The first yielded value is the batch size; each subsequent value is the
    (padded) maximum length at the next nesting depth, descending until the
    elements are no longer iterable.
    """
    yield len(nested_lists)
    current = nested_lists
    while current and all(should_iter(item) for item in current):
        longest = max(len(sub) for sub in current)
        # Round up to a multiple of 8 when fp16 is enabled (delegated to
        # `precision.pad_length`; presumably a no-op otherwise).
        yield precision.pad_length(longest)
        # Descend one nesting level by flattening the current level.
        current = list(itertools.chain.from_iterable(current))
def tensorize(self, batch, pad_token=0):
    """Convert a batch of ``(bytes, token_lengths, byte_lengths)`` rows
    into three padded tensors.

    Args:
        batch: iterable of per-example triples as produced upstream.
        pad_token: fill value used when padding the byte tensor.

    Returns:
        Tuple of (padded byte tensor, token-length tensor, byte-length
        tensor), each built via ``pad_and_tensorize``.
    """
    # Renamed from `bytes` to avoid shadowing the builtin.
    byte_seqs, token_lengths, byte_lengths = zip(*batch)
    # Fix the last dim to `max_byte_len` because the byte dimension should
    # always be `max_byte_len` no matter how long the bytes in the batch are.
    pad_shape = (
        len(batch),
        # token dim padded (to a multiple of 8 for fp16, per `precision`)
        precision.pad_length(max(len(lens) for lens in byte_lengths)),
        self.max_byte_len,
    )
    return (
        pad_and_tensorize(byte_seqs, pad_shape=pad_shape, pad_token=pad_token),
        pad_and_tensorize(token_lengths),
        pad_and_tensorize(byte_lengths),
    )
def pad_length(self, n):
    """Return ``n`` padded via ``precision.pad_length``.

    Overrides the base implementation so padded lengths become a multiple
    of 8, which is required for efficient fp16 training.
    """
    padded = precision.pad_length(n)
    return padded