def __call__(self, step):
     starting_iteration = self.steps_per_epoch * self.start_epoch
     starting_iteration = tf.cast(starting_iteration, self.dtype)
     global_step = tf.cast(step, self.dtype)
     recomp_iteration = global_step - starting_iteration + 1.
     decayed_coeff = self.coeff_scheduler(recomp_iteration)
     # This is an autograph-friendly alternative to checking Tensorflow booleans
     # in eager mode.
     scale = tf.minimum(
         tf.maximum(tf.cast(recomp_iteration, self.dtype), 0.), 1.)
     return scale * decayed_coeff
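
The tf.minimum(tf.maximum(..., 0.), 1.) pair clamps the recomputed iteration into [0, 1], so the returned coefficient is zeroed out for steps before start_epoch and passed through unchanged afterwards. A minimal standalone sketch of the same clamping, with assumed values for steps_per_epoch and start_epoch:

import tensorflow as tf

# Assumed values: 100 steps per epoch, coefficient schedule starts at epoch 2.
steps_per_epoch, start_epoch = 100, 2
starting_iteration = tf.cast(steps_per_epoch * start_epoch, tf.float32)

for step in (150, 200, 350):
    recomp_iteration = tf.cast(step, tf.float32) - starting_iteration + 1.
    scale = tf.minimum(tf.maximum(recomp_iteration, 0.), 1.)
    print(step, float(scale))  # 150 -> 0.0, 200 -> 1.0, 350 -> 1.0
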
Example #2
 def while_loop_body(iteration, distribution, inactive, old_inactive):
     """Performs one iteration of the projection."""
     del old_inactive  # Needed by the condition, but not the body (for lint).
     iteration += 1
     scale = (1.0 - tf.reduce_sum(
         distribution, axis=0, keepdims=True)) / tf.maximum(
             1.0, tf.reduce_sum(inactive, axis=0, keepdims=True))
     distribution = distribution + (scale * inactive)
     new_inactive = tf.cast(distribution > 0, distribution.dtype)
     distribution = distribution * new_inactive
     return (iteration, distribution, new_inactive, inactive)
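
A hedged sketch of how this body might be driven by tf.while_loop; the stopping condition here is an assumption (stop once the active set stabilizes or an iteration budget is exhausted), not the original caller:

import tensorflow as tf

def project_columns_onto_simplex(distribution, max_iterations=10):
    """Repeatedly redistributes mass until each column sums to one with no negatives."""
    inactive = tf.ones_like(distribution)
    old_inactive = tf.zeros_like(distribution)

    def while_loop_cond(iteration, distribution, inactive, old_inactive):
        del distribution  # Unused by the condition.
        return tf.logical_and(
            iteration < max_iterations,
            tf.reduce_any(tf.not_equal(inactive, old_inactive)))

    _, distribution, _, _ = tf.while_loop(
        while_loop_cond, while_loop_body,
        (tf.constant(0), distribution, inactive, old_inactive))
    return distribution
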
Example #3
 def parameterized_infection_dynamics(_, previous_state, parameters):
     new_infections = tfd.Poisson(
         parameters['infection_rate'] * previous_state['infected'] *
         previous_state['susceptible'] / population_size)
     new_recoveries = tfd.Poisson(previous_state['infected'] *
                                  parameters['recovery_rate'])
     return tfd.JointDistributionNamed({
         'new_infections':
         new_infections,
         'new_recoveries':
         new_recoveries,
         'susceptible':
         lambda new_infections: tfd.Deterministic(  # pylint: disable=g-long-lambda
             tf.maximum(0., previous_state['susceptible'] -
                        new_infections)),
         'infected':
         lambda new_infections, new_recoveries: tfd.Deterministic(  # pylint: disable=g-long-lambda
             tf.maximum(0., (previous_state['infected'] + new_infections
                             - new_recoveries)))
     })
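
A hedged usage sketch that draws one step of these dynamics; population_size, the previous state, and the parameter values below are all assumed for illustration:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

population_size = 1000.  # assumed; closed over by the function above
previous_state = {'susceptible': tf.constant(990.), 'infected': tf.constant(10.)}
parameters = {'infection_rate': tf.constant(1.2), 'recovery_rate': tf.constant(0.25)}

step_distribution = parameterized_infection_dynamics(0, previous_state, parameters)
next_state = step_distribution.sample(seed=42)
# tf.maximum(0., ...) keeps 'susceptible' and 'infected' from going negative.
print(next_state['susceptible'], next_state['infected'])
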
Example #4
def random_brightness(image, max_delta, impl='simclrv2'):
    """A multiplicative vs additive change of brightness."""
    if impl == 'simclrv2':
        factor = tf.random.uniform([], tf.maximum(1.0 - max_delta, 0),
                                   1.0 + max_delta)
        image = image * factor
    elif impl == 'simclrv1':
        image = tf.image.random_brightness(image, max_delta=max_delta)
    else:
        raise ValueError('Unknown impl {} for random brightness.'.format(impl))
    return image
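
A short usage sketch on a dummy image; with impl='simclrv2' the brightness change is a multiplicative factor drawn from [max(1 - max_delta, 0), 1 + max_delta]:

import tensorflow as tf

image = tf.random.uniform([32, 32, 3])  # dummy image with values in [0, 1)
brightened = random_brightness(image, max_delta=0.6, impl='simclrv2')
# Here the random factor lies in [0.4, 1.6]; 'simclrv1' would instead add a
# uniform offset in [-0.6, 0.6] via tf.image.random_brightness.
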
Example #5
 def _log_unnormalized_prob(self, x):
     # The log-probability at negative points is always -inf.
     # Catch such x's and set the output value accordingly.
     safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x),
                         0.)
     y = safe_x * self.log_rate - tf.math.lgamma(1. + safe_x)
     is_supported = tf.broadcast_to(tf.equal(x, safe_x), tf.shape(input=y))
     neg_inf = tf.fill(tf.shape(input=y),
                       value=np.array(-np.inf,
                                      dtype=y.dtype.as_numpy_dtype))
     return tf.where(is_supported, y, neg_inf)
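
The method computes the unnormalized Poisson log-probability x * log_rate - lgamma(1 + x) and returns -inf off the support. A hedged standalone check with an assumed rate, mirroring the non-interpolating (tf.floor) branch:

import tensorflow as tf

log_rate = tf.math.log(4.0)                   # assumed rate of 4
x = tf.constant([-1., 0., 2., 5.])
safe_x = tf.maximum(tf.floor(x), 0.)          # clamp negative points to 0
y = safe_x * log_rate - tf.math.lgamma(1. + safe_x)
neg_inf = tf.fill(tf.shape(y), float('-inf'))
print(tf.where(tf.equal(x, safe_x), y, neg_inf))  # the x = -1 entry is -inf
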
Example #6
 def _sample_dataset(self, seed):
     dataset = dict(
         train_locations=self._train_locations,
         train_extents=self._train_extents,
     )
     prior_samples = self.prior_distribution().sample(seed=seed)
     observation_noise_dist = self._observation_noise_fn(
         prior_samples['log_intensity'])
     counts = observation_noise_dist.sample(seed=seed)
     dataset['train_counts'] = tf.maximum(1, tf.cast(counts, tf.int32))
     return dataset
Example #7
    def _get_reinterpreted_batch_ndims(self,
                                       distribution_batch_shape_tensor=None):
        if self._static_reinterpreted_batch_ndims is not None:
            return self._static_reinterpreted_batch_ndims
        if self._reinterpreted_batch_ndims is not None:
            return tf.convert_to_tensor(self._reinterpreted_batch_ndims)

        if distribution_batch_shape_tensor is None:
            distribution_batch_shape_tensor = self.distribution.batch_shape_tensor(
            )
        return tf.maximum(0, tf.size(distribution_batch_shape_tensor) - 1)
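
The tf.maximum(0, size - 1) fallback counts the batch dimensions to reinterpret when none was specified: all but the leftmost, and never negative for a scalar batch shape. A hedged illustration of passing that count to tfd.Independent explicitly:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([3, 4, 5]), scale=1.)            # batch shape [3, 4, 5]
ndims = tf.maximum(0, tf.size(base.batch_shape_tensor()) - 1)   # 3 - 1 = 2
independent = tfd.Independent(base, reinterpreted_batch_ndims=int(ndims))
print(independent.batch_shape, independent.event_shape)         # [3] and [4, 5]
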
Example #8
def create_masks(inputs, target):
    """function that creates all masks for training/validation"""
    inputs = tf.cast(tf.math.equal(inputs, 0), tf.float32)
    target = tf.cast(tf.math.equal(target, 0), tf.float32)
    encoder_mask = inputs[:, tf.newaxis, tf.newaxis, :]
    decoder_mask = inputs[:, tf.newaxis, tf.newaxis, :]
    decoder_target_mask = target[:, tf.newaxis, tf.newaxis, :]
    look_ahead_mask = 1 - tf.linalg.band_part(
        tf.ones((target.shape[0], 1, target.shape[1], target.shape[1])), -1, 0)
    combined_mask = tf.maximum(decoder_target_mask, look_ahead_mask)
    return encoder_mask, combined_mask, decoder_mask
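
A hedged usage sketch with dummy token ids (padding id 0); tf.maximum merges the padding mask and the look-ahead mask, so a decoder position is masked whenever it is padded or lies in the future:

import tensorflow as tf

inputs = tf.constant([[5, 8, 0, 0], [3, 1, 4, 0]])
target = tf.constant([[7, 2, 0, 0], [9, 6, 5, 0]])
enc_mask, combined_mask, dec_mask = create_masks(inputs, target)
print(enc_mask.shape, combined_mask.shape, dec_mask.shape)
# (2, 1, 1, 4) (2, 1, 4, 4) (2, 1, 1, 4)
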
Example #9
 def valid_transition_fn(_, particles):
     return tfd.JointDistributionNamedAutoBatched(
         {
             'sales':
             tfd.Poisson(10. * tf.ones_like(particles['inventory'])),
             'inventory':
             lambda sales: tfd.Deterministic(  # pylint: disable=g-long-lambda
                 tf.maximum(0., particles['inventory'] - sales))
         },
         batch_ndims=1,
         validate_args=True)
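
A hedged sketch of drawing one transition for a batch of five particles; with batch_ndims=1 the joint log-probability comes back per particle rather than as a single scalar:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

particles = {'inventory': tf.fill([5], 100.), 'sales': tf.zeros([5])}
transition = valid_transition_fn(0, particles)
proposal = transition.sample(seed=7)
print(transition.log_prob(proposal).shape)  # (5,) -- one log-prob per particle
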
Example #10
    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        beta_1_power = None
        beta_2_power = None
        lr = tf.cast(self.learning_rate, variable.dtype)
        local_step = tf.cast(self.iterations + 1, variable.dtype)
        beta_1_power = tf.pow(tf.cast(self.beta_1, variable.dtype), local_step)
        beta_2_power = tf.pow(tf.cast(self.beta_2, variable.dtype), local_step)

        var_key = self._var_key(variable)
        m = self._momentums[self._index_dict[var_key]]
        v = self._velocities[self._index_dict[var_key]]

        alpha = lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            m.assign_add(-m * (1 - self.beta_1))
            m.scatter_add(
                tf.IndexedSlices(gradient.values * (1 - self.beta_1),
                                 gradient.indices))
            v.assign_add(-v * (1 - self.beta_2))
            v.scatter_add(
                tf.IndexedSlices(
                    tf.square(gradient.values) * (1 - self.beta_2),
                    gradient.indices,
                ))
            if self.amsgrad:
                v_hat = self._velocity_hats[self._index_dict[var_key]]
                v_hat.assign(tf.maximum(v_hat, v))
                v = v_hat
            variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))
        else:
            # Dense gradients.
            m.assign_add((gradient - m) * (1 - self.beta_1))
            v.assign_add((tf.square(gradient) - v) * (1 - self.beta_2))
            if self.amsgrad:
                v_hat = self._velocity_hats[self._index_dict[var_key]]
                v_hat.assign(tf.maximum(v_hat, v))
                v = v_hat
            variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))
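
A hedged, self-contained sketch of the dense branch on a single variable, showing how tf.maximum implements the AMSGrad running maximum of the second-moment estimate (hyperparameter values below are assumed):

import tensorflow as tf

lr, beta_1, beta_2, epsilon = 1e-3, 0.9, 0.999, 1e-7   # assumed defaults
variable, m, v, v_hat = (tf.Variable(1.0), tf.Variable(0.0),
                         tf.Variable(0.0), tf.Variable(0.0))

gradient, local_step = tf.constant(0.5), 1.0
beta_1_power = beta_1 ** local_step
beta_2_power = beta_2 ** local_step
alpha = lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)

m.assign_add((gradient - m) * (1 - beta_1))
v.assign_add((tf.square(gradient) - v) * (1 - beta_2))
v_hat.assign(tf.maximum(v_hat, v))   # AMSGrad: v_hat never decreases
variable.assign_sub((m * alpha) / (tf.sqrt(v_hat) + epsilon))
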
Example #11
def create_masks(inputs, target):
    """Function that creates all masks for training/validation"""
    enc_padding_mask = tf.cast(tf.math.equal(inputs, 0), tf.float32)
    enc_padding_mask = enc_padding_mask[:, tf.newaxis, tf.newaxis, :]
    dec_padding_mask = tf.cast(tf.math.equal(inputs, 0), tf.float32)
    dec_padding_mask = dec_padding_mask[:, tf.newaxis, tf.newaxis, :]
    size = target.shape[1]
    look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    dec_target_p_mask = tf.cast(tf.math.equal(target, 0), tf.float32)
    dec_target_p_mask = dec_target_p_mask[:, tf.newaxis, tf.newaxis, :]
    combined_mask = tf.maximum(dec_target_p_mask, look_ahead_mask)
    return (enc_padding_mask, combined_mask, dec_padding_mask)
Example #12
 def transition_fn_no_batch_shape(_, particles):
     # Autobatched JD defaults to treating num_particles as event shape, but
     # we need it to be batch shape to get per-particle logprobs.
     return tfd.JointDistributionNamedAutoBatched(
         {
             'sales':
             tfd.Poisson(10. * tf.ones_like(particles['inventory'])),
             'inventory':
             lambda sales: tfd.Deterministic(  # pylint: disable=g-long-lambda
                 tf.maximum(0., particles['inventory'] - sales))
         },
         validate_args=True)
Example #13
 def transition_fn_partial_batch_shape(_, particles):
     return tfd.JointDistributionNamed(
         # Using `Sample` ensures iid proposals for each particle, but not
         # per-particle log probs.
         {
             'sales':
             tfd.Sample(tfd.Poisson(10.), ps.shape(particles['sales'])),
             'inventory':
             lambda sales: tfd.Deterministic(  # pylint: disable=g-long-lambda
                 tf.maximum(0., particles['inventory'] - sales))
         },
         validate_args=True)
Example #14
def random_resize(features,
                  scale=(0.5, 2.0),
                  ensure_small=None,
                  keys=("image", ),
                  methods=("bilinear", )):
    """Randomly resize the image and label by a uniformly sampled scale.

  Args:
      features: Input dictionary containing "image", "label", and other keys.
      scale: Output image and label will be scaled by a scale sampled uniformly
        at random in this range.
      ensure_small: Ignored if None. Else, if input image size * min(scale) is
        less than ensure_small, it will adjust the scale so that the output
        image is always at least as big as ensure_small. This is useful so that
        subsequent crop operations do not go out of range.
      keys: Keys to apply resize op to. Note that keys starting with prefix
        "label" will be resized using nearest neighbour.
      methods: Resize methods per key.

  Returns:
      features with randomly scaled "images" defined by keys.
  """
    if ensure_small is None:
        scale_min, scale_max = scale
    else:
        width, height = _get_image_size(features["image"], dynamic=True)
        scale_min = tf.maximum(
            ensure_small / tf.cast(width, dtype=tf.float32),
            ensure_small / tf.cast(height, dtype=tf.float32))
        scale_max = tf.maximum(scale[1], scale_min)

    scale_chosen = tf.random.uniform(shape=(),
                                     minval=scale_min,
                                     maxval=scale_max,
                                     dtype=tf.float32)
    width, height = _get_image_size(features["image"], dynamic=True)
    new_width = tf.cast(tf.cast(width, tf.float32) * scale_chosen, tf.int32)
    new_height = tf.cast(tf.cast(height, tf.float32) * scale_chosen, tf.int32)

    return resize(features, (new_height, new_width), keys, methods)
Example #15
def pad_reflecting(x, padding_below, padding_above, axis):
    """Pads `x` with reflecting conditions above and/or below it along some axis.

  Pads `x` with reflecting conditions for `padding_below` entries below the
  tensor and `padding_above` entries above the tensor in the direction along
  `axis`. This is like using tf.pad(x, --, 'REFLECT'), except that this code
  allows for an unbounded number of reflections while tf.pad() only supports
  one reflection. Multiple reflections are necessary for wavelet
  decompositions to guard against cases where the wavelet filters are larger
  than the input tensor along `axis`, which happens often at coarse scales.
  Note that "reflecting" boundary conditions are different from "symmetric"
  boundary conditions, in that it doesn't repeat the last element:
  reflect([A, B, C, D], 2) = [C, B, A, B, C, D, C, B]
  symmet.([A, B, C, D], 2) = [B, A, A, B, C, D, D, C]

  Args:
    x: The tensor to be padded with reflecting boundary conditions.
    padding_below: The number of elements being padded below the tensor.
    padding_above: The number of elements being padded above the tensor.
    axis: The axis in x in which padding will be performed.

  Returns:
    `x` padded according to `padding_below` and `padding_above` along `axis`
    with reflecting boundary conditions.
  """
    if not isinstance(padding_below, int):
        raise ValueError(
            'Expected `padding_below` of type int, but is of type {}'.format(
                type(padding_below)))
    if not isinstance(padding_above, int):
        raise ValueError(
            'Expected `padding_above` of type int, but is of type {}'.format(
                type(padding_above)))
    if not isinstance(axis, int):
        raise ValueError(
            'Expected `axis` of type int, but is of type {}'.format(
                type(axis)))
    if not (axis >= 0 and axis < len(x.shape)):
        raise ValueError('Expected `axis` in [0, {}], but is = {}'.format(
            len(x.shape) - 1, axis))

    if padding_below == 0 and padding_above == 0:
        return tf.convert_to_tensor(x)
    n = tf.shape(x)[axis]
    # `i` contains the indices of the output padded tensor in the frame of
    # reference of the input tensor.
    i = tf.range(-padding_below, n + padding_above, dtype=tf.int32)
    # `j` contains the indices of the input tensor corresponding to the output
    # padded tensor.
    i_mod = tf.math.mod(i, tf.maximum(1, 2 * (n - 1)))
    j = tf.minimum(2 * (n - 1) - i_mod, i_mod)
    return tf.gather(x, j, axis=axis)
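
A worked check of the index arithmetic: with n = 4 and two elements of padding on each side, tf.math.mod folds the out-of-range indices back into [0, 2n - 3] and tf.minimum reflects them, reproducing the reflect([A, B, C, D], 2) pattern from the docstring:

import tensorflow as tf

x = tf.constant([10., 20., 30., 40.])   # stands in for [A, B, C, D]
print(pad_reflecting(x, padding_below=2, padding_above=2, axis=0))
# -> [30., 20., 10., 20., 30., 40., 30., 20.]  i.e. [C, B, A, B, C, D, C, B]
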
Example #16
 def call(self, inputs):
     inputs = tf.convert_to_tensor(inputs)
     if inputs.shape.rank == 1:
         inputs = tf.compat.v1.expand_dims(inputs, 1)
     # If the inputs are not floats, cast them to floats. This avoids issues
     # with int-float multiplication and division below.
     if inputs.dtype != K.floatx():
         inputs = tf.cast(inputs, K.floatx())
     # We need to reshape the mean and variance data to ensure that Tensorflow
     # broadcasts the data correctly.
     mean = tf.reshape(self.mean, self._broadcast_shape)
     variance = tf.reshape(self.variance, self._broadcast_shape)
     return ((inputs - mean) / tf.maximum(tf.sqrt(variance), K.epsilon()))
Example #17
def pairwise_l2_distance(embs1, embs2):
    """Computes pairwise distances between all rows of embs1 and embs2."""
    norm1 = tf.reduce_sum(tf.square(embs1), 1)
    norm1 = tf.reshape(norm1, [-1, 1])
    norm2 = tf.reduce_sum(tf.square(embs2), 1)
    norm2 = tf.reshape(norm2, [1, -1])

    # Max to ensure matmul doesn't produce anything negative due to floating
    # point approximations.
    dist = tf.maximum(
        norm1 + norm2 - 2.0 * tf.matmul(embs1, embs2, False, True), 0.0)

    return dist
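
A hedged sanity check: the tf.maximum(..., 0.0) only guards against tiny negative values from floating-point error, so the result equals the squared Euclidean distances computed directly (note the function returns squared distances despite its name):

import tensorflow as tf

embs1 = tf.constant([[0., 0.], [1., 0.]])
embs2 = tf.constant([[0., 3.], [1., 0.]])
print(pairwise_l2_distance(embs1, embs2))
# [[ 9.,  1.],
#  [10.,  0.]]
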
Example #18
def logmap0(y, c):
    """Hyperbolic logarithmic map at zero in the Poincare ball model.

  Args:
    y: Tensor of size B x dimension representing hyperbolic points
    c: Tensor of size 1 representing the absolute hyperbolic curvature.

  Returns:
    Tensor of shape B x dimension.
  """
    sqrt_c = tf.sqrt(c)
    y_norm = tf.maximum(tf.norm(y, axis=-1, keepdims=True), MIN_NORM)
    return y / y_norm / sqrt_c * artanh(sqrt_c * y_norm)
Example #19
        def _calc_loudness(audio, n_fft=2048, top_db=200.0, pmin=1e-20):
            """Perceptual loudness in tf, following librosa implementation."""
            librosa = tfds.core.lazy_imports.librosa
            log10 = lambda x: tf.math.log(x) / tf.math.log(10.0)

            spectra = tf.signal.stft(signals=audio,
                                     frame_length=n_fft,
                                     frame_step=int(_AUDIO_RATE //
                                                    _F0_AND_LOUDNESS_RATE),
                                     fft_length=n_fft,
                                     pad_end=True)

            power = tf.abs(spectra)**2.0
            power_db = 10.0 * log10(tf.maximum(pmin, power))
            power_db = tf.maximum(power_db, tf.reduce_max(power_db) - top_db)

            fft_frequencies = librosa.fft_frequencies(n_fft=n_fft)
            a_weighting = librosa.A_weighting(fft_frequencies)

            loudness = power_db + a_weighting[tf.newaxis, :]
            loudness = tf.reduce_mean(loudness, axis=-1)
            return loudness
Example #20
def clip_to_window(boxlist, window, filter_nonoverlapping=True, scope=None):
    """Clip bounding boxes to a window.

  This op clips any input bounding boxes (represented by bounding box
  corners) to a window, optionally filtering out boxes that do not
  overlap at all with the window.

  Args:
    boxlist: BoxList holding M_in boxes
    window: a tensor of shape [4] representing the [y_min, x_min, y_max, x_max]
      window to which the op should clip boxes.
    filter_nonoverlapping: whether to filter out boxes that do not overlap at
      all with the window.
    scope: name scope.

  Returns:
    a BoxList holding M_out boxes where M_out <= M_in
  """
    with tf.name_scope(scope, 'ClipToWindow'):
        y_min, x_min, y_max, x_max = tf.split(value=boxlist.get(),
                                              num_or_size_splits=4,
                                              axis=1)
        win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
        y_min_clipped = tf.maximum(tf.minimum(y_min, win_y_max), win_y_min)
        y_max_clipped = tf.maximum(tf.minimum(y_max, win_y_max), win_y_min)
        x_min_clipped = tf.maximum(tf.minimum(x_min, win_x_max), win_x_min)
        x_max_clipped = tf.maximum(tf.minimum(x_max, win_x_max), win_x_min)
        clipped = box_list.BoxList(
            tf.concat(
                [y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped],
                1))
        clipped = _copy_extra_fields(clipped, boxlist)
        if filter_nonoverlapping:
            areas = area(clipped)
            nonzero_area_indices = tf.cast(
                tf.reshape(tf.where(tf.greater(areas, 0.0)), [-1]), tf.int32)
            clipped = gather(clipped, nonzero_area_indices)
        return clipped
Example #21
        def _values_transform_fn(t, grid, value_grid):
            zero = tf.zeros_like(value_grid)
            is_ex_time = _is_exercise_time(t)

            def _at_least_one_swaption_pays():
                payoff_swap = tf.nn.relu(_get_swap_payoff(t))
                return tf.where(
                    tf.reshape(is_ex_time,
                               tf.concat([batch_shape, [1] * dim], axis=0)),
                    payoff_swap, zero)

            v_star = tf.cond(tf.math.reduce_any(is_ex_time),
                             _at_least_one_swaption_pays, lambda: zero)
            return grid, tf.maximum(value_grid, v_star)
Example #22
def get_linear_warmup_rsqrt_decay_lr(init_lr, hidden_size, num_warmup_steps):
    """Calculate learning rate with linear warmup and rsqrt decay."""
    num_warmup_steps = tf.cast(num_warmup_steps, tf.float32)
    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step = tf.cast(global_step, tf.float32)

    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
    learning_rate *= tf.math.rsqrt(tf.cast(hidden_size, tf.float32))
    # Apply linear warmup
    learning_rate *= tf.minimum(1.0, global_step / num_warmup_steps)
    # Apply rsqrt decay
    learning_rate *= tf.math.rsqrt(tf.maximum(global_step, num_warmup_steps))

    return learning_rate
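
The tf.minimum and tf.maximum lines together give the standard Transformer schedule: the rate grows linearly until num_warmup_steps and then decays as 1/sqrt(step). A hedged sketch of the same arithmetic with the step passed in explicitly instead of read from the global-step variable:

import tensorflow as tf

def warmup_rsqrt_lr(step, init_lr=1.0, hidden_size=512, num_warmup_steps=4000.):
    """Assumed standalone variant of the schedule above, for illustration only."""
    step = tf.cast(step, tf.float32)
    lr = init_lr * tf.math.rsqrt(tf.cast(hidden_size, tf.float32))
    lr *= tf.minimum(1.0, step / num_warmup_steps)            # linear warmup
    lr *= tf.math.rsqrt(tf.maximum(step, num_warmup_steps))   # rsqrt decay
    return lr

print(float(warmup_rsqrt_lr(2000)), float(warmup_rsqrt_lr(4000)),
      float(warmup_rsqrt_lr(16000)))
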
Example #23
def compute_power(audio,
                  sample_rate=16000,
                  frame_rate=250,
                  frame_size=1024,
                  range_db=LD_RANGE,
                  ref_db=20.7):
    """Compute power of audio in dB."""
    # TODO(hanoih@): enable `use_tf` to be True or False like `compute_loudness`
    rms_energy = compute_rms_energy(audio, sample_rate, frame_rate, frame_size)
    power_db = amplitude_to_db(rms_energy**2, use_tf=True)
    # Set dynamic range.
    power_db -= ref_db
    power_db = tf.maximum(power_db, -range_db)
    return power_db
Example #24
def create_masks(inputs, target):
    """function"""
    encoder_mask = tf.cast(tf.math.equal(inputs, 0), tf.float32)
    encoder_mask = encoder_mask[:, tf.newaxis, tf.newaxis, :]
    decoder_mask = tf.cast(tf.math.equal(inputs, 0), tf.float32)
    decoder_mask = decoder_mask[:, tf.newaxis, tf.newaxis, :]
    batch_size, seq_len_out = target.shape
    look_ahead_mask = tf.linalg.band_part(tf.ones((seq_len_out, seq_len_out)),
                                          -1, 0)
    look_ahead_mask = 1 - look_ahead_mask
    padding_mask = tf.cast(tf.math.equal(target, 0), tf.float32)
    padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]
    combined_mask = tf.maximum(look_ahead_mask, padding_mask)
    return encoder_mask, combined_mask, decoder_mask
Example #25
def _crop(x, max_length, sample):
    """Select (optionally random) crop from sequence."""
    if sample:
        # Optionally sample random starting position.
        start = tf.random.uniform(
            (),
            dtype=tf.int32,
            maxval=tf.maximum(1,
                              tf.shape(x)[0] - max_length + 1))
    else:
        start = 0

    x = x[start:(start + max_length)]
    return x
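
A hedged usage sketch; tf.maximum(1, ...) keeps the maxval of the uniform draw positive even when the sequence is shorter than max_length, in which case the crop always starts at 0 and simply returns the whole sequence:

import tensorflow as tf

sequence = tf.range(10)
print(_crop(sequence, max_length=4, sample=True))     # random 4-element window
print(_crop(sequence, max_length=4, sample=False))    # [0, 1, 2, 3]
print(_crop(tf.range(3), max_length=4, sample=True))  # [0, 1, 2], start is 0
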
Example #26
def create_masks(inputs, target):
    """
    Returns: encoder_mask, look_ahead_mask, decoder_mask
    """
    encoder_mask = tf.cast(tf.math.equal(inputs, 0), tf.float32)
    encoder_mask = encoder_mask[:, tf.newaxis, tf.newaxis, :]
    decoder_mask = tf.cast(tf.math.equal(inputs, 0), tf.float32)
    decoder_mask = decoder_mask[:, tf.newaxis, tf.newaxis, :]
    size = target.shape[1]
    look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    dec_target_mask = tf.cast(tf.math.equal(target, 0), tf.float32)
    dec_target_mask = dec_target_mask[:, tf.newaxis, tf.newaxis, :]
    combined_mask = tf.maximum(dec_target_mask, look_ahead_mask)
    return (encoder_mask, combined_mask, decoder_mask)
Example #27
def create_masks(inputs, target):
    """creates all masks for training/validation"""

    size = tf.shape(target)[1]
    encoder_mask = tf.cast(tf.math.equal(inputs, 0), tf.float32)
    encoder_mask = encoder_mask[:, tf.newaxis, tf.newaxis, :]

    look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

    dec_target_padd_mask = tf.cast(tf.math.equal(target, 0), tf.float32)
    dec_target_padd_mask = dec_target_padd_mask[:, tf.newaxis, tf.newaxis, :]
    look_ahead_mask = tf.maximum(dec_target_padd_mask, look_ahead_mask)

    return encoder_mask, look_ahead_mask, encoder_mask
Example #28
def berp(global_step, start_step, end_step, start_val, end_val, alpha=5):
    """Beta interpolation."""
    beta_dist = tfd.Beta(alpha, alpha)
    mode = beta_dist.mode()
    interp = (tf.cast(global_step - start_step, tf.float32) /
              tf.cast(end_step - start_step, tf.float32))
    interp = tf.maximum(0.0, tf.minimum(1.0, interp))
    interp = tf.where(tf.math.is_nan(interp), tf.zeros_like(interp), interp)
    interp *= mode
    val = beta_dist.prob(interp)
    val /= beta_dist.prob(mode)
    val *= (end_val - start_val)
    val += start_val
    return val
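
Here tf.maximum(0.0, tf.minimum(1.0, interp)) clamps the interpolation fraction to [0, 1], so steps outside [start_step, end_step] saturate at the endpoint values. A hedged check (tfd assumed to be tensorflow_probability.distributions):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Steps beyond end_step clamp to the end value; earlier steps interpolate.
print(float(berp(global_step=150, start_step=0, end_step=100,
                 start_val=2.0, end_val=5.0)))  # 5.0
print(float(berp(global_step=50, start_step=0, end_step=100,
                 start_val=2.0, end_val=5.0)))  # somewhere between 2.0 and 5.0
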
Example #29
def _pad_multichannel(image, ensure_small, pad_value, mode):
  """Pad to ensure `ensure_small`."""
  pad_h = tf.maximum(ensure_small[0] - tf.shape(image)[0], 0)
  pad_h_l = pad_h // 2
  pad_h_r = pad_h - pad_h_l

  pad_w = tf.maximum(ensure_small[1] - tf.shape(image)[1], 0)
  pad_w_l = pad_w // 2
  pad_w_r = pad_w - pad_w_l

  def pad_2d(x, v):
    """Pad 2D input `x` with constant value `v`."""
    return tf.pad(
        x, [[pad_h_l, pad_h_r], [pad_w_l, pad_w_r]] + [[0, 0]] *
        (len(x.shape) - 2), mode=mode, constant_values=v)

  if isinstance(pad_value, (list, tuple)):
    image_new = tf.stack(
        [pad_2d(image[:, :, i], v) for i, v in enumerate(pad_value)], axis=2)
  else:
    image_new = pad_2d(image, pad_value)

  return image_new
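
A hedged usage sketch padding a 5x6 RGB image up to at least 8x8 with a constant value; tf.maximum(..., 0) keeps the padding amounts non-negative when the image is already large enough:

import tensorflow as tf

image = tf.ones([5, 6, 3])
padded = _pad_multichannel(image, ensure_small=(8, 8), pad_value=0.0,
                           mode="CONSTANT")
print(padded.shape)  # (8, 8, 3)
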
Example #30
def get_max_num_levels(sz):
  """Returns the maximum number of levels that construct() can support.

  Args:
    sz: A tuple of ints representing some input size (batch, width, height).

  Returns:
    The maximum value for num_levels, when calling construct(im, num_levels),
    assuming `sz` is the shape of `im`.
  """
  min_sz = tf.minimum(sz[1], sz[2])
  log2 = lambda x: tf.math.log(tf.cast(x, tf.float32)) / tf.math.log(2.)
  max_num_levels = tf.cast(tf.math.ceil(log2(tf.maximum(1, min_sz))), tf.int32)
  return max_num_levels
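
A brief check of the formula; the tf.maximum(1, min_sz) guard keeps log2 well-defined for degenerate inputs, and for a (batch, 200, 320) input the smaller spatial side (200) gives ceil(log2(200)) = 8 levels:

import tensorflow as tf

print(int(get_max_num_levels((1, 200, 320))))  # 8
print(int(get_max_num_levels((1, 1, 1))))      # 0, thanks to the maximum(1, .) guard
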