Ejemplo n.º 1
0
def extract_patches(inputs: tf.Tensor, size: int) -> tf.Tensor:
    """Collect a sliding window of `size` elements for every input position.

    The input is first extended with `layers.pad_periodic` (amount
    `size - 1`, `center=True`) and then windowed with stride 1 and 'VALID'
    padding, so the number of windows equals the original length `x`.

    Args:
        inputs: Tensor with dimensions [batch, x].
        size: number of elements to include in each patch.

    Returns:
        Tensor with dimensions [batch, x, size].
    """
    # [batch, x] -> [batch, x, 1]: trailing axis added before padding.
    with_channel = inputs[..., tf.newaxis]
    wrapped = layers.pad_periodic(with_channel, size - 1, center=True)

    # The patch op takes a rank-4 input, so add one more trailing axis and
    # treat the padded x axis as the "rows" of a one-column image.
    as_image = wrapped[..., tf.newaxis]
    unit_step = [1, 1, 1, 1]
    patches = tf.extract_image_patches(
        as_image,
        ksizes=[1, size, 1, 1],
        strides=unit_step,
        rates=unit_step,
        padding='VALID',
    )

    # Drop the singleton "columns" axis, leaving [batch, x, size].
    return tf.squeeze(patches, axis=2)
def pad_periodic_1d(inputs, padding, center=False):
  """Apply `layers.pad_periodic` to a single, unbatched tensor.

  A leading and a trailing singleton axis are added before the call to
  `layers.pad_periodic`, and both are squeezed out of its result.

  Args:
    inputs: tensor to pad; expected to be 1D (indexed here as
      `inputs[tf.newaxis, :, tf.newaxis]`).
    padding: forwarded unchanged to `layers.pad_periodic`.
    center: forwarded unchanged to `layers.pad_periodic`.

  Returns:
    The result of `layers.pad_periodic` with axes 0 and 2 removed.
  """
  # Wrap the tensor in singleton axes 0 and 2 for the layers helper.
  expanded = inputs[tf.newaxis, :, tf.newaxis]
  padded = layers.pad_periodic(expanded, padding, center)
  # Remove the two singleton axes added above.
  return tf.squeeze(padded, axis=(0, 2))