def train(self, sentences):
    token_ids, token_values, token_dense_shape = self._tokenize(sentences)
    tokens_sparse = tf.sparse.SparseTensor(
        indices=token_ids, values=token_values, dense_shape=token_dense_shape)
    tokens = tf.sparse.to_dense(tokens_sparse, default_value="")

    sparse_lookup_ids = tf.sparse.SparseTensor(
        indices=tokens_sparse.indices,
        values=self._words_to_indices(tokens_sparse.values),
        dense_shape=tokens_sparse.dense_shape)
    lookup_ids = tf.sparse.to_dense(sparse_lookup_ids, default_value=0)

    # Targets are the next word for each word of the sentence.
    tokens_ids_seq = lookup_ids[:, 0:-1]
    tokens_ids_target = lookup_ids[:, 1:]

    tokens_prefix = tokens[:, 0:-1]

    # Mask determining which positions we care about for a loss: all positions
    # that have a valid non-terminal token.
    mask = tf.logical_and(
        tf.logical_not(tf.equal(tokens_prefix, "")),
        tf.logical_not(tf.equal(tokens_prefix, "<E>")))

    input_mask = tf.cast(mask, tf.int32)

    with tf.GradientTape() as t:
      sentence_embeddings = tf.nn.embedding_lookup(self._embeddings,
                                                   tokens_ids_seq)

      lstm_initial_state = self._lstm_cell.get_initial_state(
          sentence_embeddings)

      lstm_output = self._rnn_layer(
          inputs=sentence_embeddings, initial_state=lstm_initial_state)

      # Flatten the LSTM outputs from [batch, time, units] to
      # [batch * time, units] so each token position becomes one row.
      lstm_output = tf.reshape(lstm_output, [-1, self._lstm_cell.output_size])

      logits = self._logit_layer(lstm_output)

      targets = tf.reshape(tokens_ids_target, [-1])
      weights = tf.cast(tf.reshape(input_mask, [-1]), tf.float32)

      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=targets, logits=logits)

      # Final loss is the masked mean of the per-token losses.
      final_loss = tf.math.divide(
          tf.reduce_sum(tf.multiply(losses, weights)),
          tf.reduce_sum(weights),
          name="final_loss")

    watched = t.watched_variables()
    gradients = t.gradient(final_loss, watched)

    for w, g in zip(watched, gradients):
      w.assign_sub(g)

    return final_loss
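# A minimal, hypothetical driver for the `train` method above. The model class
# name, its constructor arguments, and `sentence_batches` are assumptions, not
# part of the original snippet. Note that `w.assign_sub(g)` above is plain
# gradient descent with an implicit learning rate of 1.0.
model = TextRnnModel(vocab=vocab, emb_dim=128, num_units=256)  # hypothetical
for step, sentences in enumerate(sentence_batches):
  loss = model.train(sentences)
  if step % 100 == 0:
    print("step", step, "loss", float(loss))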
def barrier_price(*,
                  volatilities,
                  strikes,
                  expiries,
                  spots,
                  barriers,
                  rebates=None,
                  discount_rates=None,
                  continuous_dividends=None,
                  cost_of_carries=None,
                  is_barrier_down=None,
                  is_knock_out=None,
                  is_call_options=None,
                  dtype=None,
                  name=None):
    """Prices barrier options in a Black-Scholes Model.

  Computes the prices of options with a single barrier in a Black-Scholes world
  as described in Ref. [1]. Note that the barrier is applied continuously.

  #### Example

  This example is taken from Ref. [2], Page 154.

  ```python
  import numpy as np
  import tf_quant_finance as tff

  dtype = np.float32
  discount_rates = np.array([.08, .08])
  continuous_dividends = np.array([.04, .04])
  spots = np.array([100., 100.])
  strikes = np.array([90., 90.])
  barriers = np.array([95., 95.])
  rebates = np.array([3., 3.])
  volatilities = np.array([.25, .25])
  expiries = np.array([.5, .5])
  is_barrier_down = np.array([True, False])
  is_knock_out = np.array([False, False])
  is_call_options = np.array([True, True])

  price = tff.black_scholes.barrier_price(
      volatilities=volatilities, strikes=strikes, expiries=expiries,
      spots=spots, barriers=barriers, rebates=rebates,
      discount_rates=discount_rates,
      continuous_dividends=continuous_dividends,
      is_barrier_down=is_barrier_down, is_knock_out=is_knock_out,
      is_call_options=is_call_options, dtype=dtype)

  # Expected output
  #  `Tensor` with values [9.024, 7.7627]
  ```

  #### References

  [1]: Lee Clewlow, Javier Llanos, Chris Strickland, Caracas Venezuela
    Pricing Exotic Options in a Black-Scholes World, 1994
    https://warwick.ac.uk/fac/soc/wbs/subjects/finance/research/wpaperseries/1994/94-54.pdf
  [2]: Espen Gaarder Haug, The Complete Guide to Option Pricing Formulas,
    2nd Edition, 1997

  Args:
    volatilities: Real `Tensor` of any shape and dtype. The volatilities to
      expiry of the options to price.
    strikes: A real `Tensor` of the same dtype and compatible shape as
      `volatilities`. The strikes of the options to be priced.
    expiries: A real `Tensor` of same dtype and compatible shape as
      `volatilities`. The expiry of each option. The units should be such that
      `expiry * volatility**2` is dimensionless.
    spots: A real `Tensor` of any shape that broadcasts to the shape of the
      `volatilities`. The current spot price of the underlying.
    barriers: A real `Tensor` of same dtype as the `volatilities` and of the
      shape that broadcasts with `volatilities`. The barriers of each option.
    rebates: A real `Tensor` of same dtype as the `volatilities` and of the
      shape that broadcasts with `volatilities`. For knockouts, this is a
      fixed cash payout in case the barrier is breached. For knockins, this is a
      fixed cash payout in case the barrier level is not breached. In the former
      case, the rebate is paid immediately on breach whereas in the latter, the
      rebate is paid at the expiry of the option.
      Default value: `None` which maps to no rebates.
    discount_rates: A real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`.
      Discount rates, or risk free rates.
      Default value: `None`, equivalent to discount_rate = 0.
    continuous_dividends: A real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`. A
      continuous dividend rate paid by the underlier. If `None`, then
      defaults to zero dividends.
      Default value: `None`, equivalent to zero dividends.
    cost_of_carries: An optional real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`.
      Cost of storing a physical commodity, the cost of interest paid when
      long, or the opportunity cost, or the cost of paying dividends when short.
      If not `None`, `continuous_dividends` is calculated as r - c,
      where r is the `discount_rates` and c is `cost_of_carries`.
      Default value: `None`, which maps to `discount_rates - continuous_dividends`.
    is_barrier_down: A real `Tensor` of `boolean` values and of the shape
      that broadcasts with `volatilities`. True if barrier is below asset
      price at expiration.
      Default value: `True`.
    is_knock_out: A real `Tensor` of `boolean` values and of the shape
      that broadcasts with `volatilities`. True if option is knock out
      else false.
      Default value: `True`.
    is_call_options: A real `Tensor` of `boolean` values and of the shape
      that broadcasts with `volatilities`. True if option is call else
      false.
      Default value: `True`.
    dtype: Optional `tf.DType`. If supplied, the dtype to be used for conversion
      of any supplied non-`Tensor` arguments to `Tensor`.
      Default value: `None` which maps to the default dtype inferred by
      TensorFlow.
    name: str. The name for the ops created by this function.
      Default value: `None` which is mapped to the default name `barrier_price`.
  Returns:
    option_prices: A `Tensor` of same shape as `spots`. The approximate price of
    the barrier options under the Black-Scholes model.
  """
    # The computation is done as in Ref [2], where each integral is split into
    # two matrices: the first contains the algebraic terms and the second
    # contains the probability distribution terms. Masks select the appropriate
    # terms for each integral, and a masked dot product of the corresponding
    # rows of the two matrices yields the price of each barrier option.
    if (continuous_dividends is not None) and (cost_of_carries is not None):
        raise ValueError(
            'At most one of continuous_dividends and cost of carries '
            'may be supplied')
    with tf.name_scope(name or 'barrier_price'):
        spots = tf.convert_to_tensor(spots, dtype=dtype, name='spots')
        dtype = spots.dtype
        strikes = tf.convert_to_tensor(strikes, dtype=dtype, name='strikes')
        volatilities = tf.convert_to_tensor(volatilities,
                                            dtype=dtype,
                                            name='volatilities')
        expiries = tf.convert_to_tensor(expiries, dtype=dtype, name='expiries')
        barriers = tf.convert_to_tensor(barriers, dtype=dtype, name='barriers')
        if rebates is not None:
            rebates = tf.convert_to_tensor(rebates,
                                           dtype=dtype,
                                           name='rebates')
        else:
            rebates = tf.zeros_like(spots, dtype=dtype, name='rebates')

        # Convert all to tensor and enforce float dtype where required
        if discount_rates is not None:
            discount_rates = tf.convert_to_tensor(discount_rates,
                                                  dtype=dtype,
                                                  name='discount_rates')
        else:
            discount_rates = tf.convert_to_tensor(
                0.0, dtype=dtype, name='discount_rates')

        if continuous_dividends is None:
            continuous_dividends = tf.convert_to_tensor(
                0.0, dtype=dtype, name='continuous_dividends')

        if cost_of_carries is not None:
            cost_of_carries = tf.convert_to_tensor(cost_of_carries,
                                                   dtype=dtype,
                                                   name='cost_of_carries')
        else:
            cost_of_carries = discount_rates - continuous_dividends

        if is_barrier_down is None:
            is_barrier_down = tf.constant(1, name='is_barrier_down')
        else:
            is_barrier_down = tf.convert_to_tensor(is_barrier_down,
                                                   dtype=tf.bool,
                                                   name='is_barrier_down')
            is_barrier_down = tf.where(is_barrier_down, 1, 0)
        if is_knock_out is None:
            is_knock_out = tf.constant(1, name='is_knock_out')
        else:
            is_knock_out = tf.convert_to_tensor(is_knock_out,
                                                dtype=tf.bool,
                                                name='is_knock_out')
            is_knock_out = tf.where(is_knock_out, 1, 0)
        if is_call_options is None:
            is_call_options = tf.constant(1, name='is_call_options')
        else:
            is_call_options = tf.convert_to_tensor(is_call_options,
                                                   dtype=tf.bool,
                                                   name='is_call_options')
            is_call_options = tf.where(is_call_options, 1, 0)

        # Indices in [0, 7], computed as
        # 4 * is_barrier_down + 2 * is_knock_out + is_call_options,
        # select the appropriate mask row for each option.
        indices = tf.bitwise.left_shift(is_barrier_down,
                                        2) + tf.bitwise.left_shift(
                                            is_knock_out, 1) + is_call_options

        # Masks select the appropriate terms for the integral approximations.
        # Each integral is separated into an algebraic term and a probability
        # distribution term, giving 12 entries per matrix
        # (6 integrals, 2 terms each).
        # shape = [8, 12]
        mask_matrix_greater_strike = tf.constant([
            [1, 1, -1, -1, 0, 0, 1, 1, 1, 1, 0, 0],  # up and in put
            [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],  # up and in call
            [0, 0, 1, 1, 0, 0, -1, -1, 0, 0, 1, 1],  # up and out put
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],  # up and out call
            [0, 0, 1, 1, -1, -1, 1, 1, 0, 0, 1, 1],  # down and in put
            [0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],  # down and in call
            [1, 1, -1, -1, 1, 1, -1, -1, 0, 0, 1, 1],  # down and out put
            [1, 1, 0, 0, -1, -1, 0, 0, 0, 0, 1, 1]  # down and out call
        ])

        mask_matrix_lower_strike = tf.constant([
            [0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],  # up and in put
            [0, 0, 1, 1, -1, -1, 1, 1, 1, 1, 0, 0],  # up and in call
            [1, 1, 0, 0, -1, -1, 0, 0, 0, 0, 1, 1],  # up and out put
            [1, 1, -1, -1, 1, 1, -1, -1, 0, 0, 1, 1],  # up and out call
            [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],  # down and in put
            [1, 1, -1, -1, 0, 0, 1, 1, 1, 1, 0, 0],  # down and in call
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],  # down and out put
            [0, 0, 1, 1, 0, 0, -1, -1, 0, 0, 1, 1]  # down and out call
        ])

        # Create masks
        # Masks are shape [strikes.shape, 12]
        masks_lower = tf.gather(mask_matrix_lower_strike, indices, axis=0)
        masks_greater = tf.gather(mask_matrix_greater_strike, indices, axis=0)
        strikes_greater = tf.expand_dims(strikes > barriers, axis=-1)
        masks = tf.where(strikes_greater, masks_greater, masks_lower)
        masks = tf.cast(masks, dtype=dtype)
        one = tf.constant(1, dtype=dtype)
        call_or_put = tf.cast(tf.where(tf.equal(is_call_options, 0), -one,
                                       one),
                              dtype=dtype)
        below_or_above = tf.cast(tf.where(tf.equal(is_barrier_down, 0), -one,
                                          one),
                                 dtype=dtype)

        # Calculate params for integrals
        sqrt_var = volatilities * tf.math.sqrt(expiries)
        mu = (cost_of_carries) - ((volatilities**2) / 2)
        lamda = 1 + (mu / (volatilities**2))
        x = (tf.math.log(spots / strikes) / (sqrt_var)) + (lamda * sqrt_var)
        x1 = (tf.math.log(spots / barriers) / (sqrt_var)) + (lamda * sqrt_var)
        y = (tf.math.log((barriers**2) / (spots * strikes)) /
             (sqrt_var)) + (lamda * sqrt_var)
        y1 = (tf.math.log(barriers / spots) / (sqrt_var)) + (lamda * sqrt_var)
        b = ((mu**2) +
             (2 * (volatilities**2) * discount_rates)) / (volatilities**2)
        z = (tf.math.log(barriers / spots) / (sqrt_var)) + (b * sqrt_var)
        a = mu / (volatilities**2)

        # Other params used for integrals
        discount_rates_exponent = tf.math.exp(-discount_rates * expiries,
                                              name='discount_rates_exponent')
        continuous_dividends_exponent = tf.math.exp(
            (cost_of_carries - discount_rates) * expiries,
            name='continuous_dividends_exponent')
        barriers_ratio = tf.math.divide(barriers, spots, name='barriers_ratio')
        spots_term = call_or_put * spots * continuous_dividends_exponent
        strikes_term = call_or_put * strikes * discount_rates_exponent

        # rank is used to stack elements and reduce_sum
        strike_rank = strikes.shape.rank

        # Constructing Matrix with first and second algebraic terms for each
        # integral [strike.shape, 12]
        terms_mat = tf.stack(
            (spots_term, -strikes_term, spots_term, -strikes_term, spots_term *
             (barriers_ratio**(2 * lamda)), -strikes_term *
             (barriers_ratio**((2 * lamda) - 2)), spots_term *
             (barriers_ratio**(2 * lamda)), -strikes_term *
             (barriers_ratio**((2 * lamda) - 2)), rebates *
             discount_rates_exponent, -rebates * discount_rates_exponent *
             (barriers_ratio**((2 * lamda) - 2)), rebates *
             (barriers_ratio**(a + b)), rebates * (barriers_ratio**(a - b))),
            name='term_matrix',
            axis=strike_rank)

        # Constructing Matrix with first and second norm for each integral
        # [strikes.shape, 12]
        cdf_mat = tf.stack(
            (call_or_put * x, call_or_put *
             (x - sqrt_var), call_or_put * x1, call_or_put *
             (x1 - sqrt_var), below_or_above * y, below_or_above *
             (y - sqrt_var), below_or_above * y1, below_or_above *
             (y1 - sqrt_var), below_or_above *
             (x1 - sqrt_var), below_or_above * (y1 - sqrt_var),
             below_or_above * z, below_or_above * (z - (2 * b * sqrt_var))),
            name='cdf_matrix',
            axis=strike_rank)
        cdf_mat = _ncdf(cdf_mat)
        # Calculating and returning price for each option
        return tf.reduce_sum(masks * terms_mat * cdf_mat, axis=strike_rank)
    def testBijector(self, bijector_name, data):
        tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')

        bijector, event_dim = self._draw_bijector(bijector_name, data)

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        xs = self._draw_domain_tensor(bijector, data, event_dim)
        wrt_vars = [xs] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

        # For scalar bijectors, verify correctness of the _is_increasing method.
        # TODO(b/148459057): Except, don't verify Softfloor on Guitar because
        # of numerical problem.
        def exception(bijector):
            if not tfp_hps.running_under_guitar():
                return False
            if isinstance(bijector, tfb.Softfloor):
                return True
            if isinstance(bijector, tfb.Invert):
                return exception(bijector.bijector)
            return False

        if (bijector.forward_min_event_ndims == 0
                and bijector.inverse_min_event_ndims == 0
                and not exception(bijector)):
            dydx = grads[0]
            hp.note('dydx: {}'.format(dydx))
            isfinite = tf.math.is_finite(dydx)
            incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
                dydx, 0)  # pylint: disable=protected-access
            self.assertAllEqual(
                isfinite & incr_or_slope_eq0,
                isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

        # FLDJ: Check differentiation through forward log det jacobian with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=xs.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=True)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.forward_log_det_jacobian(
                    xs + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars,
                            grads)

        # Inverse mapping: Check differentiation through inverse mapping with
        # respect to the codomain "input" and parameter variables.  Also check that
        # any variables are not referenced overmuch.
        ys = self._draw_codomain_tensor(bijector, data, event_dim)
        wrt_vars = [ys] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse(ys + 0)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

        # ILDJ: Check differentiation through inverse log det jacobian with respect
        # to the codomain "input" and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=ys.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=False)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.inverse_log_det_jacobian(
                    ys + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars,
                            grads)

        # Check that the outputs of forward_dtype and inverse_dtype match the dtypes
        # of the outputs of forward and inverse.
        self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
        self.assertAllEqualNested(xs.dtype, bijector.inverse_dtype(ys.dtype))
 def loop_cond(i, decodes_BxT, unused_cache_BxU_dict):
     finished_B = tf.reduce_any(tf.equal(decodes_BxT, eos_id), axis=1)
     return tf.logical_and(i < max_decode_len,
                           tf.logical_not(tf.reduce_all(finished_B)))
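# Hypothetical wiring for the condition above: `loop_body`, `init_decodes_BxT`,
# and `init_cache_BxU_dict` are assumptions not shown in the snippet. The loop
# keeps decoding until every sequence in the batch has emitted `eos_id` or
# `max_decode_len` is reached.
_, decodes_BxT, cache_BxU_dict = tf.while_loop(
    loop_cond, loop_body,
    loop_vars=[tf.constant(0), init_decodes_BxT, init_cache_BxU_dict])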
class PredicatesTest(test_util.TestCase):
    @parameterized.named_parameters(
        dict(testcase_name='_greater_true',
             predicate=ps.greater,
             args_fn=lambda: [tf.constant(1), tf.constant(0)],
             kwargs=dict(),
             expected=True),
        dict(testcase_name='_greater_false',
             predicate=ps.greater,
             args_fn=lambda: [tf.constant(-.1),
                              tf.constant(0.)],
             kwargs=dict(),
             expected=False),
        dict(testcase_name='_greater_none',
             predicate=ps.greater,
             args_fn=lambda: [tf.constant(1) + tf.constant(0),
                              tf.constant(0)],
             kwargs=dict(),
             expected=True),
        dict(testcase_name='_less_true',
             predicate=ps.less,
             args_fn=lambda: [tf.constant(-1), tf.constant(0)],
             kwargs=dict(),
             expected=True),
        dict(testcase_name='_log',
             predicate=ps.log,
             args_fn=lambda: [tf.constant(1.)],
             kwargs=dict(),
             expected=0.),
        dict(testcase_name='_equal_true',
             predicate=ps.equal,
             args_fn=lambda: [tf.constant(0), 0],
             kwargs=dict(),
             expected=True),
        dict(testcase_name='_equal_false',
             predicate=ps.equal,
             args_fn=lambda: [tf.constant(1), tf.constant(0)],
             kwargs=dict(),
             expected=False),
        dict(testcase_name='_and_true',
             predicate=ps.logical_and,
             args_fn=lambda: [True, tf.constant(True)],
             kwargs=dict(),
             expected=True),
        dict(testcase_name='_and_none',
             predicate=ps.logical_and,
             args_fn=lambda: [tf.constant(True),
                              tf.equal(1, 1)],
             kwargs=dict(),
             expected=True),
        dict(testcase_name='_or_true',
             predicate=ps.logical_or,
             args_fn=lambda: [tf.constant(True),
                              tf.constant(False)],
             kwargs=dict(),
             expected=True),
        dict(testcase_name='_all_true',
             predicate=ps.reduce_all,
             args_fn=lambda: [[tf.constant(True)] * 10],
             kwargs=dict(),
             expected=True),
        dict(
            testcase_name='_all_false',
            predicate=ps.reduce_all,
            args_fn=lambda: [[
                tf.constant(True),  # pylint: disable=g-long-lambda
                True,
                tf.constant(False)
            ]],
            kwargs=dict(),
            expected=False),
        dict(
            testcase_name='_all_with_axis',
            predicate=ps.reduce_all,
            args_fn=lambda: (
                [
                    [True, tf.constant(True)],  # pylint: disable=g-long-lambda
                    [False, True]
                ], ),
            kwargs=dict(axis=1),
            expected=[True, False]),
        dict(
            testcase_name='_all_with_name',
            predicate=ps.reduce_all,
            args_fn=lambda: (
                [
                    [True, tf.constant(True)],  # pylint: disable=g-long-lambda
                    [False, True]
                ], ),
            kwargs=dict(axis=1, name='my_name'),
            expected=[True, False]),
        dict(
            testcase_name='_any_true',
            predicate=ps.reduce_any,
            args_fn=lambda: [[
                tf.constant(True),  # pylint: disable=g-long-lambda
                tf.constant(False),
                tf.constant(False)
            ]],
            kwargs=dict(),
            expected=True),
        dict(testcase_name='_any_false',
             predicate=ps.reduce_any,
             args_fn=lambda: [[tf.constant(False)] * 23],
             kwargs=dict(),
             expected=False),
        dict(
            testcase_name='_any_keepdims',
            predicate=ps.reduce_any,
            args_fn=lambda: (
                [
                    [True, tf.constant(True)],  # pylint: disable=g-long-lambda
                    [False, True]
                ], ),
            kwargs=dict(keepdims=True),
            expected=[[True]]),
    )
    def test_static_predicate(self, predicate, args_fn, kwargs, expected):
        actual = predicate(*args_fn(), **kwargs)
        self.assertAllCloseAccordingToType(expected, actual)
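# A small illustrative sketch, assuming `ps` in the test above is TensorFlow
# Probability's `prefer_static` module: when all arguments are statically
# known, these predicates resolve to concrete Python/NumPy values (usable in
# ordinary `if` statements) rather than Tensors.
from tensorflow_probability.python.internal import prefer_static as ps

print(ps.equal(1, 1))                # True, resolved statically
print(ps.reduce_all([True, False]))  # False, resolved statically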
def interpolate(x,
                x_data,
                y_data,
                left_slope=None,
                right_slope=None,
                validate_args=False,
                optimize_for_tpu=False,
                dtype=None,
                name=None):
    """Performs linear interpolation for supplied points.

  Given a set of knots whose x- and y- coordinates are in `x_data` and `y_data`,
  this function returns y-values for x-coordinates in `x` via piecewise
  linear interpolation.

  `x_data` must be non-decreasing, but `y_data` does not need to be, because we
  do not require the function approximated by these knots to be monotonic.

  #### Examples

  ```python
  x = [-10, -1, 1, 3, 6, 7, 8, 15, 18, 25, 30, 35]
  x_data = [-1, 2, 6, 8, 18, 30.0]
  y_data = [10, -1, -5, 7, 9, 20]

  result = interpolate(x, x_data, y_data)
  # [ 10, 10, 2.66666667, -2, -5, 1, 7, 8.4, 9, 15.41666667, 20, 20]
  ```

  Args:
    x: x-coordinates at which to interpolate. An N-D `Tensor` of real dtype.
      The first N-1 dimensions represent batching dimensions.
    x_data: x-coordinates of the knots. An N-D `Tensor` of real dtype. Should
      be sorted in non-decreasing order. The first N-1 dimensions represent
      batching dimensions.
    y_data: y-coordinates of the knots. An N-D `Tensor` of real dtype. Should
      have a shape compatible with `x_data`. The first N-1 dimensions represent
      batching dimensions.
    left_slope: The slope to use for extrapolation with x-coordinate smaller
      than the min `x_data`. It's a 0-D or N-D `Tensor`.
      Default value: `None`, which maps to `0.0` meaning constant extrapolation,
      i.e. extrapolated value will be the leftmost `y_data`.
    right_slope: The slope to use for extrapolation with x-coordinate greater
      than the max `x_data`. It's a 0-D or N-D `Tensor`.
      Default value: `None` which maps to `0.0` meaning constant extrapolation,
      i.e. extrapolated value will be the rightmost `y_data`.
    validate_args: Python `bool` that indicates whether the function checks
      that the shapes of `x_data` and `y_data` are equal and that the elements
      in `x_data` are non-decreasing. If this value is set to `False` and the
      elements in `x_data` are not sorted in non-decreasing order, the result
      of the linear interpolation may be wrong.
      Default value: `False`.
    optimize_for_tpu: A Python bool. If `True`, the algorithm uses one-hot
      encoding to look up the indices of `x` in `x_data`. This significantly
      improves performance of the algorithm on a TPU device but may slow down
      performance on the CPU.
      Default value: `False`.
    dtype: Optional tf.dtype for `x`, x_data`, `y_data`, `left_slope` and
      `right_slope`.
      Default value: `None` which means that the `dtype` inferred by TensorFlow
      is used.
    name: Python str. The name prefixed to the ops created by this function.
      Default value: `None` which maps to 'linear_interpolation'.

  Returns:
    An N-D `Tensor` of real dtype containing the interpolated y-values for the
    x-coordinates in `x`.
  """
    name = name or "linear_interpolation"
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, dtype=dtype, name="x")
        dtype = dtype or x.dtype
        x_data = tf.convert_to_tensor(x_data, dtype=dtype, name="x_data")
        y_data = tf.convert_to_tensor(y_data, dtype=dtype, name="y_data")
        # Try broadcast batch_shapes
        x, x_data = utils.broadcast_common_batch_shape(x, x_data)
        x, y_data = utils.broadcast_common_batch_shape(x, y_data)
        x_data, y_data = utils.broadcast_common_batch_shape(x_data, y_data)

        batch_shape = x.shape.as_list()[:-1]
        if not batch_shape:
            x = tf.expand_dims(x, 0)
            x_data = tf.expand_dims(x_data, 0)
            y_data = tf.expand_dims(y_data, 0)

        if left_slope is None:
            left_slope = tf.constant(0.0, dtype=x.dtype, name="left_slope")
        else:
            left_slope = tf.convert_to_tensor(left_slope,
                                              dtype=dtype,
                                              name="left_slope")
        if right_slope is None:
            right_slope = tf.constant(0.0, dtype=x.dtype, name="right_slope")
        else:
            right_slope = tf.convert_to_tensor(right_slope,
                                               dtype=dtype,
                                               name="right_slope")
        control_deps = []
        if validate_args:
            # Check that `x_data` elements are non-decreasing
            diffs = x_data[..., 1:] - x_data[..., :-1]
            assertion = tf.compat.v1.debugging.assert_greater_equal(
                diffs,
                tf.zeros_like(diffs),
                message="x_data is not sorted in non-decreasing order.")
            control_deps.append(assertion)
            # Check that the shapes of `x_data` and `y_data` are equal
            control_deps.append(
                tf.compat.v1.assert_equal(tf.shape(x_data), tf.shape(y_data)))

        with tf.control_dependencies(control_deps):
            # Get upper bound indices for `x`.
            upper_indices = tf.searchsorted(x_data,
                                            x,
                                            side="left",
                                            out_type=tf.int32)
            x_data_size = x_data.shape.as_list()[-1]
            at_min = tf.equal(upper_indices, 0)
            at_max = tf.equal(upper_indices, x_data_size)
            # Create tensors in order to be used by `tf.where`.
            # `values_min` are extrapolated values for x-coordinates less than or
            # equal to `x_data[..., 0]`.
            # `values_max` are extrapolated values for x-coordinates greater than
            # `x_data[..., -1]`.

            values_min = tf.expand_dims(
                y_data[..., 0], -1) + left_slope * (x - tf.broadcast_to(
                    tf.expand_dims(x_data[..., 0], -1), shape=tf.shape(x)))
            values_max = tf.expand_dims(
                y_data[..., -1], -1) + right_slope * (x - tf.broadcast_to(
                    tf.expand_dims(x_data[..., -1], -1), shape=tf.shape(x)))

            # `tf.where` evaluates all branches, need to cap indices to ensure it
            # won't go out of bounds.
            lower_encoding = tf.math.maximum(upper_indices - 1, 0)
            upper_encoding = tf.math.minimum(upper_indices, x_data_size - 1)
            # Prepare indices for `tf.gather` or `tf.one_hot`
            # TODO(b/156720909): Extract get_slice logic into a common utilities
            # module for cubic and linear interpolation
            if optimize_for_tpu:
                lower_encoding = tf.one_hot(lower_encoding,
                                            x_data_size,
                                            dtype=dtype)
                upper_encoding = tf.one_hot(upper_encoding,
                                            x_data_size,
                                            dtype=dtype)

            def get_slice(x, encoding):
                if optimize_for_tpu:
                    return tf.math.reduce_sum(tf.expand_dims(x, axis=-2) *
                                              encoding,
                                              axis=-1)
                else:
                    return tf.gather(x,
                                     encoding,
                                     axis=-1,
                                     batch_dims=x.shape.rank - 1)

            x_data_lower = get_slice(x_data, lower_encoding)
            x_data_upper = get_slice(x_data, upper_encoding)
            y_data_lower = get_slice(y_data, lower_encoding)
            y_data_upper = get_slice(y_data, upper_encoding)

            # NaNs in the unselected branches of `tf.where` could propagate
            # through the gradient calculation, so we guard the values here; in
            # this case that means ensuring there is no division by zero.
            x_data_diff = x_data_upper - x_data_lower
            floor_x_diff = tf.where(at_min | at_max, x_data_diff + 1,
                                    x_data_diff)
            interpolated = y_data_lower + (x - x_data_lower) * (
                y_data_upper - y_data_lower) / floor_x_diff

            interpolated = tf.where(at_min, values_min, interpolated)
            interpolated = tf.where(at_max, values_max, interpolated)
            if batch_shape:
                return interpolated
            else:
                return tf.squeeze(interpolated, 0)
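# A small usage sketch of `interpolate` above, with expected values worked out
# by hand: non-zero `left_slope` / `right_slope` make the extrapolated branches
# linear rather than constant.
x = tf.constant([-2.0, 0.0, 0.5, 1.0, 3.0])
x_data = tf.constant([0.0, 1.0])
y_data = tf.constant([0.0, 2.0])
values = interpolate(x, x_data, y_data, left_slope=1.0, right_slope=-1.0)
# Expected: [-2.0, 0.0, 1.0, 2.0, 0.0]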
    def testBijector(self, bijector_name, data):
        tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        event_dim = data.draw(hps.integers(min_value=2, max_value=6))
        bijector = data.draw(
            bijectors(bijector_name=bijector_name,
                      event_dim=event_dim,
                      enable_vars=True))
        self.evaluate(tf.group(*[v.initializer for v in bijector.variables]))

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        # TODO(axch): Would be nice to get rid of all this shape inference logic and
        # just rely on a notion of batch and event shape for bijectors, so we can
        # pass those through `domain_tensors` and `codomain_tensors` and use
        # `tensors_in_support`.  However, `RationalQuadraticSpline` behaves weirdly
        # somehow and I got confused.
        codomain_event_shape = [event_dim] * bijector.inverse_min_event_ndims
        codomain_event_shape = constrain_inverse_shape(bijector,
                                                       codomain_event_shape)
        shp = bijector.inverse_event_shape(codomain_event_shape)
        shp = tensorshape_util.concatenate(
            data.draw(
                tfp_hps.broadcast_compatible_shape(
                    shp[:shp.ndims - bijector.forward_min_event_ndims])),
            shp[shp.ndims - bijector.forward_min_event_ndims:])
        xs = tf.identity(data.draw(domain_tensors(bijector, shape=shp)),
                         name='xs')
        wrt_vars = [xs] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

        # For scalar bijectors, verify correctness of the _is_increasing method.
        if (bijector.forward_min_event_ndims == 0
                and bijector.inverse_min_event_ndims == 0):
            dydx = grads[0]
            hp.note('dydx: {}'.format(dydx))
            isfinite = tf.math.is_finite(dydx)
            incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
                dydx, 0)  # pylint: disable=protected-access
            self.assertAllEqual(
                isfinite & incr_or_slope_eq0,
                isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

        # FLDJ: Check differentiation through forward log det jacobian with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=xs.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=True)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.forward_log_det_jacobian(
                    xs + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars,
                            grads)

        # Inverse mapping: Check differentiation through inverse mapping with
        # respect to the codomain "input" and parameter variables.  Also check that
        # any variables are not referenced overmuch.
        domain_event_shape = [event_dim] * bijector.forward_min_event_ndims
        domain_event_shape = constrain_forward_shape(bijector,
                                                     domain_event_shape)
        shp = bijector.forward_event_shape(domain_event_shape)
        shp = tensorshape_util.concatenate(
            data.draw(
                tfp_hps.broadcast_compatible_shape(
                    shp[:shp.ndims - bijector.inverse_min_event_ndims])),
            shp[shp.ndims - bijector.inverse_min_event_ndims:])
        ys = tf.identity(data.draw(codomain_tensors(bijector, shape=shp)),
                         name='ys')
        wrt_vars = [ys] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse(ys + 0)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

        # ILDJ: Check differentiation through inverse log det jacobian with respect
        # to the codomain "input" and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=ys.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=False)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.inverse_log_det_jacobian(
                    ys + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars,
                            grads)
  def map_fn(x):
    """Internal function to flat_map over.

    Consumes a batch of input examples and produces a variable number of output
    examples.
    Args:
      x: a single example
    Returns:
      a tf.data.Dataset
    """
    partial = empty_example.copy()
    i = tf.zeros([], dtype=tf.int32)
    dynamic_batch_size = tf.shape(x[keys[0]])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[length[k]])
      outputs[k + '_position'] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[length[k]])
    def cond_fn(i, partial, outputs):
      del partial, outputs
      return i < dynamic_batch_size
    def body_fn(i, partial, outputs):
      """Body function for while_loop.

      Args:
        i: integer scalar
        partial: dictionary of Tensor (partially-constructed example)
        outputs: dictionary of TensorArray
      Returns:
        A triple containing the new values of the inputs.
      """
      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]), length[k]))
      def false_fn():
        return write_packed_example(partial, outputs)
      def true_fn():
        return partial, outputs
      partial, outputs = tf.cond(can_append, true_fn, false_fn)
      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:length[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + '_position'] = tf.concat(
            [partial[k + '_position'],
             tf.range(new_seq_len, dtype=tf.int32)], 0)
      partial = new_partial
      return i+1, partial, outputs

    i, partial, outputs = \
        tf.while_loop(
            cond_fn, body_fn, (i, partial, outputs),
            shape_invariants=(
                tf.TensorShape([]),
                {k: tf.TensorShape([None]) for k in keys_etc},
                {k: tf.TensorShape(None) for k in keys_etc},
            )
        )
    partial, outputs = write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
    for k in keys:
      packed[k + '_segmentation'] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) *
          tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed
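# A tiny, hand-checked illustration of the segmentation formula used above:
# positions restart at 0 for each packed example, so the cumulative sum of
# (position == 0) numbers the segments 1, 2, ..., and multiplying by
# (token != 0) zeroes out the padding.
import tensorflow as tf

packed_tokens = tf.constant([[5, 6, 7, 8, 9, 0]])
packed_positions = tf.constant([[0, 1, 2, 0, 1, 0]])
segment_ids = (
    tf.cumsum(tf.cast(tf.equal(packed_positions, 0), tf.int32), axis=1) *
    tf.cast(tf.not_equal(packed_tokens, 0), tf.int32))
# segment_ids == [[1, 1, 1, 2, 2, 0]]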
def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
    """Squeeze or expand last dimension if needed.

    1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
    (using `remove_squeezable_dimensions`).
    2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
    from the new rank of `y_pred`.
    If `sample_weight` is scalar, it is kept scalar.

    This will use static shape if available. Otherwise, it will add graph
    operations, which could result in a performance hit.

    Args:
      y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
      y_true: Optional label `Tensor` whose dimensions match `y_pred`.
      sample_weight: Optional weight scalar or `Tensor` whose dimensions match
        `y_pred`.

    Returns:
      Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
      the last dimension squeezed; `sample_weight` could instead be extended by
      one dimension.
      If `sample_weight` is None, (y_pred, y_true) is returned.
    """
    y_pred_shape = y_pred.shape
    y_pred_rank = y_pred_shape.ndims
    if y_true is not None:

        # If sparse matrix is provided as `y_true`, the last dimension in
        # `y_pred` may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)), y_pred =
        # [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3)) In
        # this case, we should not try to remove squeezable dimension.
        y_true_shape = y_true.shape
        y_true_rank = y_true_shape.ndims
        if (y_true_rank is not None) and (y_pred_rank is not None):
            # Use static rank for `y_true` and `y_pred`.
            if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
                y_true, y_pred = remove_squeezable_dimensions(y_true, y_pred)
        else:
            # Use dynamic rank.
            rank_diff = tf.rank(y_pred) - tf.rank(y_true)
            squeeze_dims = lambda: remove_squeezable_dimensions(y_true, y_pred)
            is_last_dim_1 = tf.equal(1, tf.shape(y_pred)[-1])
            maybe_squeeze_dims = lambda: tf.cond(
                is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred)
            )
            y_true, y_pred = tf.cond(
                tf.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims
            )

    if sample_weight is None:
        return y_pred, y_true

    weights_shape = sample_weight.shape
    weights_rank = weights_shape.ndims
    if weights_rank == 0:  # If weights is scalar, do nothing.
        return y_pred, y_true, sample_weight

    if (y_pred_rank is not None) and (weights_rank is not None):
        # Use static rank.
        if weights_rank - y_pred_rank == 1:
            sample_weight = tf.squeeze(sample_weight, [-1])
        elif y_pred_rank - weights_rank == 1:
            sample_weight = tf.expand_dims(sample_weight, [-1])
        return y_pred, y_true, sample_weight

    # Use dynamic rank.
    weights_rank_tensor = tf.rank(sample_weight)
    rank_diff = weights_rank_tensor - tf.rank(y_pred)
    maybe_squeeze_weights = lambda: tf.squeeze(sample_weight, [-1])

    def _maybe_expand_weights():
        expand_weights = lambda: tf.expand_dims(sample_weight, [-1])
        return tf.cond(
            tf.equal(rank_diff, -1), expand_weights, lambda: sample_weight
        )

    def _maybe_adjust_weights():
        return tf.cond(
            tf.equal(rank_diff, 1), maybe_squeeze_weights, _maybe_expand_weights
        )

    # squeeze or expand last dim of `sample_weight` if its rank differs by 1
    # from the new rank of `y_pred`.
    sample_weight = tf.cond(
        tf.equal(weights_rank_tensor, 0),
        lambda: sample_weight,
        _maybe_adjust_weights,
    )
    return y_pred, y_true, sample_weight
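# Illustrative sketch of the rank adjustments described in the docstring above
# (shapes chosen by hand for this example): the trailing size-1 dimension of
# `y_true` is squeezed away, and the 1-D `sample_weight` gains a trailing
# dimension to match `y_pred`.
y_pred = tf.zeros([8, 1])
y_true = tf.zeros([8, 1, 1])
sample_weight = tf.ones([8])
y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
    y_pred, y_true, sample_weight)
# y_true.shape == (8, 1); sample_weight.shape == (8, 1)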
 def find_next_odd(v):
   v1 = v + 1
   while tf.equal(v1 % 2, 0):
     v1 = v1 + 1
   return v1
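# Hypothetical usage sketch: under tf.function, AutoGraph converts the Python
# `while` loop above into a tf.while_loop, so the function also accepts Tensor
# arguments.
find_next_odd_fn = tf.function(find_next_odd)
print(find_next_odd_fn(tf.constant(4)))  # tf.Tensor(5, shape=(), dtype=int32)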
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        if is_init and not dtype_util.is_integer(
                self.mixture_distribution.dtype):
            raise ValueError(
                '`mixture_distribution.dtype` ({}) is not over integers'.
                format(dtype_util.name(self.mixture_distribution.dtype)))

        if tensorshape_util.rank(
                self.mixture_distribution.event_shape) is not None:
            if tensorshape_util.rank(
                    self.mixture_distribution.event_shape) != 0:
                raise ValueError(
                    '`mixture_distribution` must have scalar `event_dim`s')
        elif self.validate_args:
            assertions += [
                assert_util.assert_equal(
                    tf.size(self.mixture_distribution.event_shape_tensor()),
                    0,
                    message=
                    '`mixture_distribution` must have scalar `event_dim`s'),
            ]

        # pylint: disable=protected-access
        mixture_dist_param = (self.mixture_distribution._probs
                              if self.mixture_distribution._logits is None else
                              self.mixture_distribution._logits)
        km = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                1)[-1])
        kc = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(
                self.components_distribution.batch_shape, 1)[-1])
        component_bst = None
        if km is not None and kc is not None:
            if km != kc:
                raise ValueError(
                    '`mixture_distribution` components ({}) does not '
                    'equal `components_distribution.batch_shape[-1]` '
                    '({})'.format(km, kc))
        elif self.validate_args:
            if km is None:
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                km = tf.shape(mixture_dist_param)[-1]
            if kc is None:
                component_bst = self.components_distribution.batch_shape_tensor(
                )
                kc = component_bst[-1]
            assertions += [
                assert_util.assert_equal(
                    km,
                    kc,
                    message=(
                        '`mixture_distribution` components does not equal '
                        '`components_distribution.batch_shape[-1]`')),
            ]

        mdbs = self.mixture_distribution.batch_shape
        cdbs = tensorshape_util.with_rank_at_least(
            self.components_distribution.batch_shape, 1)[:-1]
        if (tensorshape_util.is_fully_defined(mdbs)
                and tensorshape_util.is_fully_defined(cdbs)):
            if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                raise ValueError(
                    '`mixture_distribution.batch_shape` (`{}`) is not '
                    'compatible with `components_distribution.batch_shape` '
                    '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                                    tensorshape_util.as_list(cdbs)))
        elif self.validate_args:
            if not tensorshape_util.is_fully_defined(mdbs):
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                mdbs = tf.shape(mixture_dist_param)[:-1]
            if not tensorshape_util.is_fully_defined(cdbs):
                if component_bst is None:
                    component_bst = self.components_distribution.batch_shape_tensor(
                    )
                cdbs = component_bst[:-1]
            assertions += [
                assert_util.assert_equal(
                    distribution_utils.pick_vector(
                        tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
                    cdbs,
                    message=(
                        '`mixture_distribution.batch_shape` is not '
                        'compatible with `components_distribution.batch_shape`'
                    ))
            ]

        return assertions
 def grad(grad_ys):
   large_float_like_x = np.sqrt(np.finfo(x.dtype.as_numpy_dtype()).max)
   safe_grads = tf.where(
       tf.equal(x, 0), large_float_like_x, 0.5 * tf.math.rsqrt(x))
   return grad_ys * safe_grads
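# A minimal sketch (an assumption, not the surrounding library's actual code)
# of how a gradient function like `grad` above is typically attached with
# tf.custom_gradient: the forward pass is an ordinary sqrt, while the gradient
# at x == 0 is capped at a large finite value instead of becoming infinite.
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def safe_sqrt(x):
  def grad(grad_ys):
    large_float_like_x = np.sqrt(np.finfo(x.dtype.as_numpy_dtype()).max)
    safe_grads = tf.where(
        tf.equal(x, 0), large_float_like_x, 0.5 * tf.math.rsqrt(x))
    return grad_ys * safe_grads
  return tf.sqrt(x), grad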
 def single_step(features, labels):
   with tf.GradientTape() as tape:
     # Log summaries on the last step of the training loop to match
     # logging frequency of other scalar summaries.
     #
     # Notes:
     # 1. Summary ops on TPUs get outside compiled so they do not affect
     #    performance.
     # 2. Summaries are recorded only on replica 0. So effectively this
     #    summary would be written once per host when should_record == True.
     # 3. optimizer.iterations is incremented in the call to apply_gradients.
     #    So we use  `iterations + 1` here so that the step number matches
     #    those of scalar summaries.
     # 4. We intentionally run the summary op before the actual model
     #    training so that it can run in parallel.
     should_record = tf.equal((optimizer.iterations + 1) % steps_per_loop, 0)
     with tf.summary.record_if(should_record):
       # Only log augmented images for the first tower.
       tf.summary.image(
           'image', features[:, :, :, :3], step=optimizer.iterations + 1)
     projection_head_outputs, supervised_head_outputs = model(
         features, training=True)
     loss = None
     if projection_head_outputs is not None:
       outputs = projection_head_outputs
       con_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
           outputs,
           hidden_norm=FLAGS.hidden_norm,
           temperature=FLAGS.temperature,
           strategy=strategy)
       if loss is None:
         loss = con_loss
       else:
         loss += con_loss
       metrics.update_pretrain_metrics_train(contrast_loss_metric,
                                             contrast_acc_metric,
                                             contrast_entropy_metric,
                                             con_loss, logits_con,
                                             labels_con)
     if supervised_head_outputs is not None:
       outputs = supervised_head_outputs
       l = labels['labels']
       if FLAGS.train_mode == 'pretrain' and FLAGS.lineareval_while_pretraining:
         l = tf.concat([l, l], 0)
       sup_loss = obj_lib.add_supervised_loss(labels=l, logits=outputs)
       if loss is None:
         loss = sup_loss
       else:
         loss += sup_loss
       metrics.update_finetune_metrics_train(supervised_loss_metric,
                                             supervised_acc_metric, sup_loss,
                                             l, outputs)
     weight_decay = model_lib.add_weight_decay(
         model, adjust_per_optimizer=True)
     weight_decay_metric.update_state(weight_decay)
     loss += weight_decay
     total_loss_metric.update_state(loss)
     # The default behavior of `apply_gradients` is to sum gradients from all
     # replicas so we divide the loss by the number of replicas so that the
     # mean gradient is applied.
     loss = loss / strategy.num_replicas_in_sync
     logging.info('Trainable variables:')
     for var in model.trainable_variables:
       logging.info(var.name)
     grads = tape.gradient(loss, model.trainable_variables)
     optimizer.apply_gradients(zip(grads, model.trainable_variables))
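# Hypothetical outer loop for the replica step above; `strategy`, the dataset
# `iterator`, and `steps_per_loop` are assumptions. Each replica runs
# `single_step`, and because the loss was divided by `num_replicas_in_sync`,
# the per-replica gradients summed by `apply_gradients` give the mean gradient.
@tf.function
def train_multiple_steps(iterator):
  for _ in tf.range(steps_per_loop):
    features, labels = next(iterator)
    strategy.run(single_step, args=(features, labels))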
  def pack_batch(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Internal function to map over.

    Consumes a batch of input examples and produces a variable number of output
    examples.

    Args:
      x: a single example
    Returns:
      a tf.data.Dataset
    """
    keys = list(feature_lengths)
    partial = empty_example.copy()
    first_key, *_ = keys
    dynamic_batch_size = tf.shape(x[first_key])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True,
          element_shape=[feature_lengths[k]])
      outputs[k + "_positions"] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True,
          element_shape=[feature_lengths[k]])

    for i in tf.range(0, dynamic_batch_size):
      tf.autograph.experimental.set_loop_options(
          shape_invariants=[
              (partial, {k: tf.TensorShape([None]) for k in keys_etc}),
              (outputs, {k: tf.TensorShape(None) for k in keys_etc})]
      )

      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]),
                feature_lengths[k]))

      if not can_append:
        partial, outputs = _write_packed_example(partial, outputs)

      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:feature_lengths[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + "_positions"] = tf.concat(
            [partial[k + "_positions"],
             tf.range(new_seq_len, dtype=tf.int32)], 0)
      partial = new_partial

    partial, outputs = _write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
    for k in keys:
      packed[k + "_segment_ids"] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + "_positions"], 0), tf.int32), axis=1)
          * tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed
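# A minimal standalone sketch (illustrative values, not from the snippet above)
# of how the "*_segment_ids" features are derived: positions restart at 0 for
# every example packed into a row, so a cumulative sum of (position == 0)
# numbers the examples 1, 2, ..., and multiplying by (token != 0) zeroes out
# the padding.
import tensorflow as tf

tokens = tf.constant([[5, 6, 7, 8, 9, 0]])     # two examples packed, then padding
positions = tf.constant([[0, 1, 2, 0, 1, 0]])  # per-example positions
segment_ids = (
    tf.cumsum(tf.cast(tf.equal(positions, 0), tf.int32), axis=1)
    * tf.cast(tf.not_equal(tokens, 0), tf.int32))
# segment_ids == [[1, 1, 1, 2, 2, 0]]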
def main(_):
    # TODO(anthonyjliu): Enable debugger from flags
    if FLAGS.debug and FLAGS.tensorboard_debug_address:
        raise ValueError(
            "The --debug and --tensorboard_debug_address flags are mutually "
            "exclusive.")
    if FLAGS.debug:
        tf.debugging.enable_check_numerics()
    elif FLAGS.tensorboard_debug_address:
        raise NotImplementedError(
            "Tensorboard Debugger Plugin support for debug_mnist_v2 is not "
            "implemented yet")

    # Import data
    if FLAGS.fake_data:
        imgs = tf.random.uniform(maxval=256,
                                 shape=(1000, 28, 28),
                                 dtype=tf.int32)
        labels = tf.random.uniform(maxval=10, shape=(1000, ), dtype=tf.int32)
        mnist_train = imgs, labels
        mnist_test = imgs, labels
    else:
        mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()

    @tf.function
    def format_example(imgs, labels):
        """Formats each training and test example to work with our model."""
        imgs = tf.reshape(imgs, [-1, 28 * 28])
        imgs = tf.cast(imgs, tf.float32) / 255.0
        labels = tf.one_hot(labels, depth=10, dtype=tf.float32)
        return imgs, labels

    train_ds = tf.data.Dataset.from_tensor_slices(mnist_train).shuffle(
        FLAGS.train_batch_size * FLAGS.max_steps,
        seed=RAND_SEED).batch(FLAGS.train_batch_size)
    train_ds = train_ds.map(format_example)

    test_ds = tf.data.Dataset.from_tensor_slices(mnist_test).repeat().batch(
        len(mnist_test[0]))
    test_ds = test_ds.map(format_example)

    def get_dense_weights(input_dim, output_dim):
        """Initializes the parameters for a single dense layer."""
        initial_kernel = tf.keras.initializers.TruncatedNormal(mean=0.0,
                                                               stddev=0.1,
                                                               seed=RAND_SEED)
        kernel = tf.Variable(initial_kernel([input_dim, output_dim]))
        bias = tf.Variable(tf.constant(0.1, shape=[output_dim]))

        return kernel, bias

    @tf.function
    def dense_layer(weights, input_tensor, act=tf.nn.relu):
        """Runs the forward computation for a single dense layer."""
        kernel, bias = weights
        preactivate = tf.matmul(input_tensor, kernel) + bias

        activations = act(preactivate)
        return activations

    # init model
    hidden = get_dense_weights(IMAGE_SIZE**2, HIDDEN_SIZE)
    logits = get_dense_weights(HIDDEN_SIZE, NUM_LABELS)
    variables = hidden + logits

    @tf.function
    def model(x):
        """Feed forward function of the model.

    Args:
      x: a (?, 28*28) tensor consisting of the feature inputs for a batch of
        examples.

    Returns:
      A (?, 10) tensor containing the class scores for each example.
    """
        hidden_act = dense_layer(hidden, x)
        logits_act = dense_layer(logits, hidden_act, tf.identity)
        y = tf.nn.softmax(logits_act)
        return y

    @tf.function
    def loss(logits, labels):
        """Calculates cross entropy loss."""
        diff = -(labels * tf.math.log(logits))
        loss = tf.reduce_mean(diff)
        return loss
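    # Note: taking the log of softmax probabilities as above can produce -inf
    # once a probability underflows to exactly 0; applying
    # tf.nn.softmax_cross_entropy_with_logits to the pre-softmax activations
    # is the numerically stable alternative.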

    train_batches = iter(train_ds)
    test_batches = iter(test_ds)
    optimizer = tf.optimizers.Adam(learning_rate=FLAGS.learning_rate)
    for i in range(FLAGS.max_steps):
        x_train, y_train = next(train_batches)
        x_test, y_test = next(test_batches)

        # Train Step
        with tf.GradientTape() as tape:
            y = model(x_train)
            loss_val = loss(y, y_train)
        grads = tape.gradient(loss_val, variables)

        optimizer.apply_gradients(zip(grads, variables))

        # Evaluation Step
        y = model(x_test)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_test, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print("Accuracy at step %d: %s" % (i, accuracy.numpy()))
Example #16
def interpolate1d(x, values, tangents):
    r"""Perform cubic hermite spline interpolation on a 1D spline.

  The x coordinates of the spline knots are at [0 : 1 : len(values)-1].
  Queries outside of the range of the spline are computed using linear
  extrapolation. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline
  for details, where "x" corresponds to `x`, "p" corresponds to `values`, and
  "m" corresponds to `tangents`.

  Args:
    x: A tensor of any size of single or double precision floats containing the
      set of values to be used for interpolation into the spline.
    values: A vector of single or double precision floats containing the value
      of each knot of the spline being interpolated into. Must be the same
      length as `tangents` and the same type as `x`.
    tangents: A vector of single or double precision floats containing the
      tangent (derivative) of each knot of the spline being interpolated into.
      Must be the same length as `values` and the same type as `x`.

  Returns:
    The result of interpolating along the spline defined by `values`, and
    `tangents`, using `x` as the query values. Will be the same length and type
    as `x`.
  """
    # `values` and `tangents` must have the same type as `x`.
    tf.debugging.assert_type(values, x.dtype)
    tf.debugging.assert_type(tangents, x.dtype)
    float_dtype = x.dtype
    assert_ops = [
        # `values` must be a vector.
        tf.Assert(tf.equal(tf.rank(values), 1), [tf.shape(values)]),
        # `tangents` must be a vector.
        tf.Assert(tf.equal(tf.rank(tangents), 1), [tf.shape(tangents)]),
        # `values` and `tangents` must have the same length.
        tf.Assert(
            tf.equal(tf.shape(values)[0],
                     tf.shape(tangents)[0]),
            [tf.shape(values)[0], tf.shape(tangents)[0]]),
    ]
    with tf.control_dependencies(assert_ops):
        # Find the indices of the knots below and above each x.
        x_lo = tf.cast(
            tf.floor(
                tf.clip_by_value(x, 0.,
                                 tf.cast(tf.shape(values)[0] - 2,
                                         float_dtype))), tf.int32)
        x_hi = x_lo + 1

        # Compute the relative distance between each `x` and the knot below it.
        t = x - tf.cast(x_lo, float_dtype)

        # Compute the cubic hermite expansion of `t`.
        t_sq = tf.square(t)
        t_cu = t * t_sq
        h01 = -2. * t_cu + 3. * t_sq
        h00 = 1. - h01
        h11 = t_cu - t_sq
        h10 = h11 - t_sq + t

        # Linearly extrapolate above and below the extents of the spline for all
        # values.
        value_before = tangents[0] * t + values[0]
        value_after = tangents[-1] * (t - 1.) + values[-1]

        # Cubically interpolate between the knots below and above each query point.
        neighbor_values_lo = tf.gather(values, x_lo)
        neighbor_values_hi = tf.gather(values, x_hi)
        neighbor_tangents_lo = tf.gather(tangents, x_lo)
        neighbor_tangents_hi = tf.gather(tangents, x_hi)
        value_mid = (neighbor_values_lo * h00 + neighbor_values_hi * h01 +
                     neighbor_tangents_lo * h10 + neighbor_tangents_hi * h11)

        # Return the interpolated or extrapolated values for each query point,
        # depending on whether or not the query lies within the span of the spline.
        return tf.where(t < 0., value_before,
                        tf.where(t > 1., value_after, value_mid))
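# Minimal usage sketch for `interpolate1d` with illustrative values: the three
# knots sit at x = 0, 1, 2, and queries outside [0, 2] are linearly
# extrapolated.
import tensorflow as tf

values = tf.constant([0., 1., 0.], dtype=tf.float32)
tangents = tf.constant([0., 0., 0.], dtype=tf.float32)
queries = tf.constant([-1., 0.5, 1., 2.5], dtype=tf.float32)
print(interpolate1d(queries, values, tangents))
# -1. and 2.5 are extrapolated from the end knots; 0.5 lies half-way between
# the first two knots, and 1. returns the middle knot value exactly.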
Example #17
 def _maybe_expand_weights():
     expand_weights = lambda: tf.expand_dims(sample_weight, [-1])
     return tf.cond(
         tf.equal(rank_diff, -1), expand_weights, lambda: sample_weight
     )
Example #18
def decode_fn(value,
              data_aug=False,
              max_num_points=245760,
              max_num_bboxes=100,
              class_id=1,
              difficulty=1,
              pillar_map_size=(256, 256),
              pillar_map_range=(75.2, 75.2)):
    """Decode function."""
    tensor_dict = waymo_decoder.decode_tf_example(
        serialized_example=value, features=waymo_decoder.FEATURE_SPEC)
    frame_valid = tensor_dict['frame_valid']

    points_xyz = tensor_dict['lidars']['points_xyz']
    points_feature = tensor_dict['lidars']['points_feature']
    points_mask = tensor_dict['lidars']['points_mask']

    all_points_xyz = tensor_dict['lidars']['all_points_xyz']
    all_points_xyz_transformed = (
        tensor_dict['lidars']['all_points_xyz_transformed'])
    all_points_feature = tensor_dict['lidars']['all_points_feature']
    all_points_mask = tensor_dict['lidars']['all_points_mask']

    bboxes = tensor_dict['objects']['box']
    bboxes_label = tensor_dict['objects']['label']
    bboxes_speed = tensor_dict['objects']['speed']
    bboxes_difficulty = tensor_dict['objects']['combined_difficulty_level']
    bboxes_detection_difficulty = (
        tensor_dict['objects']['combined_difficulty_level'])

    bboxes_difficulty = bboxes_difficulty <= difficulty
    bboxes_mask = tf.equal(bboxes_label, class_id)
    bboxes_mask = tf.math.logical_and(bboxes_mask, bboxes_difficulty)
    bboxes_mask = tf.cast(bboxes_mask, dtype=tf.dtypes.float32)

    num_valid_bboxes = tf_util.get_shape(bboxes)[0]
    bboxes_index = tf.math.top_k(bboxes_mask,
                                 k=tf.math.minimum(max_num_bboxes,
                                                   num_valid_bboxes))[1]
    bboxes_mask = tf.gather(bboxes_mask, bboxes_index)
    bboxes_label = tf.gather(bboxes_label, bboxes_index)
    bboxes = tf.gather(bboxes, bboxes_index)
    bboxes_speed = tf.gather(bboxes_speed, bboxes_index)

    bboxes = tf_util.pad_or_trim_to(bboxes, [max_num_bboxes, 7])
    bboxes_label = tf_util.pad_or_trim_to(bboxes_label, [max_num_bboxes])
    bboxes_speed = tf_util.pad_or_trim_to(bboxes_speed, [max_num_bboxes, 2])
    bboxes_difficulty = tf_util.pad_or_trim_to(bboxes_difficulty,
                                               [max_num_bboxes])
    bboxes_mask = tf_util.pad_or_trim_to(bboxes_mask, [max_num_bboxes])

    if data_aug:
        (points_xyz, points_mask, bboxes, all_points_xyz,
         all_points_xyz_transformed, all_points_mask) = augment(
             points_xyz=points_xyz,
             points_mask=points_mask,
             bboxes=bboxes,
             all_points_xyz=all_points_xyz,
             all_points_xyz_transformed=all_points_xyz_transformed,
             all_points_mask=all_points_mask)

    (pillar_map_xyz, pillar_map_bboxes, pillar_map_bboxes_label,
     pillar_map_if_in_bboxes, pillar_map_centerness,
     pillar_map_bboxes_index) = (assign_bboxes(
         pillar_map_size=pillar_map_size,
         pillar_map_range=pillar_map_range,
         bboxes=bboxes,
         bboxes_label=bboxes_label,
         bboxes_mask=bboxes_mask))

    pillar_map_xyz = tf.reshape(pillar_map_xyz, [-1, 3])
    pillar_map_bboxes = tf.reshape(pillar_map_bboxes, [-1, 7])
    pillar_map_bboxes_label = tf.reshape(pillar_map_bboxes_label, [-1])
    pillar_map_if_in_bboxes = tf.reshape(pillar_map_if_in_bboxes, [-1])
    pillar_map_centerness = tf.reshape(pillar_map_centerness, [-1])
    pillar_map_bboxes_index = tf.reshape(pillar_map_bboxes_index, [-1])

    all_points_mask = tf.expand_dims(all_points_mask, axis=-1)

    all_points = tf.concat([
        all_points_xyz, all_points_xyz_transformed, all_points_feature,
        all_points_mask
    ],
                           axis=-1)

    num_frames, num_points, num_features = tf_util.get_shape(all_points)
    all_points = tf.reshape(all_points,
                            [num_frames * num_points, num_features])

    return {
        'points_xyz': points_xyz,
        'points_feature': points_feature,
        'points_mask': points_mask,
        'bboxes': bboxes,
        'bboxes_label': bboxes_label,
        'bboxes_mask': bboxes_mask,
        'bboxes_speed': bboxes_speed,
        'pillar_map_xyz': pillar_map_xyz,
        'pillar_map_bboxes': pillar_map_bboxes,
        'pillar_map_if_in_bboxes': pillar_map_if_in_bboxes,
    }
Example #19
 def _maybe_adjust_weights():
     return tf.cond(
         tf.equal(rank_diff, 1), maybe_squeeze_weights, _maybe_expand_weights
     )
Example #20
def _at_least_x_are_equal(a, b, x):
    """At least `x` of `a` and `b` `Tensors` are equal."""
    match = tf.equal(a, b)
    match = tf.cast(match, tf.int32)
    return tf.greater_equal(tf.reduce_sum(match), x)
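# Minimal usage sketch with illustrative shape tensors: the helper returns True
# when at least `x` of the corresponding entries in `a` and `b` are equal.
import tensorflow as tf

a = tf.constant([4, 224, 224, 3])
b = tf.constant([4, 256, 256, 3])
print(_at_least_x_are_equal(a, b, 2))  # True: two of the four entries match.
print(_at_least_x_are_equal(a, b, 3))  # False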
Example #21
def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None
):
    """Squeeze last dim if ranks differ from expected by exactly 1.

    In the common case where we expect shapes to match, `expected_rank_diff`
    defaults to 0, and we squeeze the last dimension of the larger rank if they
    differ by 1.

    But, for example, if `labels` contains class IDs and `predictions` contains
    1 probability per class, we expect `predictions` to have 1 more dimension
    than `labels`, so `expected_rank_diff` would be 1. In this case, we'd
    squeeze `labels` if `rank(predictions) - rank(labels) == 0`, and
    `predictions` if `rank(predictions) - rank(labels) == 2`.

    This will use static shape if available. Otherwise, it will add graph
    operations, which could result in a performance hit.

    Args:
      labels: Label values, a `Tensor` whose dimensions match `predictions`.
      predictions: Predicted values, a `Tensor` of arbitrary dimensions.
      expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
      name: Name of the op.

    Returns:
      Tuple of `labels` and `predictions`, possibly with last dim squeezed.
    """
    with backend.name_scope(name or "remove_squeezable_dimensions"):
        if not tf_utils.is_tensor_or_extension_type(predictions):
            predictions = tf.convert_to_tensor(predictions)
        if not tf_utils.is_tensor_or_extension_type(labels):
            labels = tf.convert_to_tensor(labels)
        predictions_shape = predictions.shape
        predictions_rank = predictions_shape.ndims
        labels_shape = labels.shape
        labels_rank = labels_shape.ndims
        if (labels_rank is not None) and (predictions_rank is not None):
            # Use static rank.
            rank_diff = predictions_rank - labels_rank
            if rank_diff == expected_rank_diff + 1 and predictions_shape.dims[
                -1
            ].is_compatible_with(1):
                predictions = tf.squeeze(predictions, [-1])
            elif rank_diff == expected_rank_diff - 1 and labels_shape.dims[
                -1
            ].is_compatible_with(1):
                labels = tf.squeeze(labels, [-1])
            return labels, predictions

        # Use dynamic rank.
        rank_diff = tf.rank(predictions) - tf.rank(labels)
        if (predictions_rank is None) or (
            predictions_shape.dims[-1].is_compatible_with(1)
        ):
            predictions = tf.cond(
                tf.equal(expected_rank_diff + 1, rank_diff),
                lambda: tf.squeeze(predictions, [-1]),
                lambda: predictions,
            )
        if (labels_rank is None) or (
            labels_shape.dims[-1].is_compatible_with(1)
        ):
            labels = tf.cond(
                tf.equal(expected_rank_diff - 1, rank_diff),
                lambda: tf.squeeze(labels, [-1]),
                lambda: labels,
            )
        return labels, predictions
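# Minimal usage sketch (assumes the Keras utilities this helper imports,
# `backend` and `tf_utils`, are available): per-example scores of shape
# [batch, 1] have their trailing dimension squeezed so they match labels of
# shape [batch].
import tensorflow as tf

labels = tf.constant([1, 0, 1])                   # shape [3]
predictions = tf.constant([[0.9], [0.1], [0.4]])  # shape [3, 1]
labels, predictions = remove_squeezable_dimensions(labels, predictions)
# labels.shape == [3], predictions.shape == [3]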
Example #22
 def nonzero(x):
     return tf.where(tf.equal(x, 0), 1e-6, x)
Example #23
 def _is_equal_or_close(self, a, b):
     if dtype_util.is_integer(self.outcomes.dtype):
         return tf.equal(a, b)
     return tf.abs(a - b) < self._atol + self._rtol * tf.abs(b)
Example #24
def _brent_loop_body(state, params, constants):
    """Performs one iteration of the Brent root-finding algorithm.

  Args:
    state: A Python `_BrentSearchState` namedtuple.
    params: A Python `_BrentSearchParams` namedtuple.
    constants: A Python `_BrentSearchConstants` namedtuple.

  Returns:
    The `Tensor`s to use for the next iteration of the algorithm.
  """

    best_estimate = state.best_estimate
    last_estimate = state.last_estimate
    contrapoint = state.contrapoint
    value_at_best_estimate = state.value_at_best_estimate
    value_at_last_estimate = state.value_at_last_estimate
    value_at_contrapoint = state.value_at_contrapoint
    step_to_best_estimate = state.step_to_best_estimate
    step_to_last_estimate = state.step_to_last_estimate
    num_iterations = state.num_iterations
    finished = state.finished

    # If the root is between the last two estimates, use the worst of the two
    # as new contrapoint. Adjust step sizes accordingly.
    replace_contrapoint = ~finished & (
        value_at_last_estimate * value_at_best_estimate < constants.zero_value)

    contrapoint = tf.where(replace_contrapoint, last_estimate, contrapoint)
    value_at_contrapoint = tf.where(replace_contrapoint,
                                    value_at_last_estimate,
                                    value_at_contrapoint)

    step_to_last_estimate = tf.where(replace_contrapoint,
                                     best_estimate - last_estimate,
                                     step_to_last_estimate)
    step_to_best_estimate = tf.where(replace_contrapoint,
                                     step_to_last_estimate,
                                     step_to_best_estimate)

    # If the contrapoint is a better guess than the current root estimate, swap
    # them. Also, replace the worst of the two with the current contrapoint.
    replace_best_estimate = tf.where(
        finished, constants.false,
        tf.math.abs(value_at_contrapoint) <
        tf.math.abs(value_at_best_estimate))

    last_estimate = tf.where(replace_best_estimate, best_estimate,
                             last_estimate)
    best_estimate = tf.where(replace_best_estimate, contrapoint, best_estimate)
    contrapoint = tf.where(replace_best_estimate, last_estimate, contrapoint)

    value_at_last_estimate = tf.where(replace_best_estimate,
                                      value_at_best_estimate,
                                      value_at_last_estimate)
    value_at_best_estimate = tf.where(replace_best_estimate,
                                      value_at_contrapoint,
                                      value_at_best_estimate)
    value_at_contrapoint = tf.where(replace_best_estimate,
                                    value_at_last_estimate,
                                    value_at_contrapoint)

    # Compute the tolerance used to control root search at the current position
    # and the step size corresponding to the bisection method.
    root_tolerance = 0.5 * (
        params.absolute_root_tolerance +
        params.relative_root_tolerance * tf.math.abs(best_estimate))
    bisection_step = 0.5 * (contrapoint - best_estimate)

    # Mark the search as finished if either:
    # 1. the maximum number of iterations has been reached;
    # 2. the desired tolerance has been reached (even if no root was found);
    # 3. the current root estimate is good enough.
    # Using zero as `function_tolerance` will check for exact roots and match
    # both Brent's original algorithm and the SciPy implementation.
    finished |= (num_iterations >= params.max_iterations) | (
        tf.math.abs(bisection_step) <
        root_tolerance) | (~tf.math.is_finite(value_at_best_estimate)) | (
            tf.math.abs(value_at_best_estimate) <= params.function_tolerance)

    # Determine whether interpolation or extrapolation are worth performing at
    # the current position.
    compute_short_step = tf.where(
        finished, constants.false,
        (root_tolerance < tf.math.abs(step_to_last_estimate)) &
        (tf.math.abs(value_at_best_estimate) <
         tf.math.abs(value_at_last_estimate)))

    short_step = tf.where(
        compute_short_step,
        tf.where(
            # The contrapoint cannot be equal to the current root estimate since
            # they have opposite signs. However, it may be equal to the previous
            # estimate.
            tf.equal(last_estimate, contrapoint),
            # If so, use the secant method to avoid a division by zero which
            # would occur if using extrapolation.
            _secant_step(best_estimate, last_estimate, value_at_best_estimate,
                         value_at_last_estimate),
            # Pass values of the objective function as x values, and root
            # estimates as y values in order to perform *inverse* extrapolation.
            _quadratic_interpolation_step(value_at_best_estimate,
                                          value_at_last_estimate,
                                          value_at_contrapoint, best_estimate,
                                          last_estimate, contrapoint)),
        # Default to zero if using bisection.
        constants.zero)

    # Use the step calculated above if both:
    # 1. step size < |previous step size|
    # 2. step size < 3/4 * |contrapoint - current root estimate|
    # Ensure that `short_step` was calculated by guarding the calculation with
    # `compute_short_step`.
    use_short_step = tf.where(
        compute_short_step, 2 * tf.math.abs(short_step) < tf.minimum(
            3 * tf.math.abs(bisection_step) - root_tolerance,
            tf.math.abs(step_to_last_estimate)), constants.false)

    # Revert to bisection when not using `short_step`.
    step_to_last_estimate = tf.where(use_short_step, step_to_best_estimate,
                                     bisection_step)
    step_to_best_estimate = tf.where(
        finished, constants.zero,
        tf.where(use_short_step, short_step, bisection_step))

    # Update the previous and current root estimates.
    last_estimate = tf.where(finished, last_estimate, best_estimate)
    best_estimate += tf.where(
        finished, constants.zero,
        tf.where(root_tolerance < tf.math.abs(step_to_best_estimate),
                 step_to_best_estimate,
                 tf.where(bisection_step > 0, root_tolerance,
                          -root_tolerance)))

    value_at_last_estimate = tf.where(finished, value_at_last_estimate,
                                      value_at_best_estimate)
    value_at_best_estimate = tf.where(finished, value_at_best_estimate,
                                      params.objective_fn(best_estimate))

    num_iterations = tf.where(finished, num_iterations, num_iterations + 1)

    return [
        _BrentSearchState(best_estimate=best_estimate,
                          last_estimate=last_estimate,
                          contrapoint=contrapoint,
                          value_at_best_estimate=value_at_best_estimate,
                          value_at_last_estimate=value_at_last_estimate,
                          value_at_contrapoint=value_at_contrapoint,
                          step_to_best_estimate=step_to_best_estimate,
                          step_to_last_estimate=step_to_last_estimate,
                          num_iterations=num_iterations,
                          finished=finished)
    ]
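# Hypothetical illustration of the secant update mentioned in the comments
# above (the real `_secant_step` helper is defined elsewhere in the module):
# one secant-method step from two estimates and their objective values.
import tensorflow as tf

def secant_step_sketch(x1, x0, f1, f0):
  # x_new = x1 - f1 * (x1 - x0) / (f1 - f0)
  return x1 - f1 * (x1 - x0) / (f1 - f0)

# For f(x) = x**2 - 2 with estimates 2.0 and 1.0 (f = 2.0 and -1.0):
print(secant_step_sketch(tf.constant(2.0), tf.constant(1.0),
                         tf.constant(2.0), tf.constant(-1.0)))  # ~1.3333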
Example #25
def _resample_using_log_points(log_probs, sample_shape, log_points, name=None):
    """Resample from `log_probs` using supplied points in interval `[0, 1]`."""

    # We divide up the unit interval [0, 1] according to the provided
    # probability distributions using `cumulative_logsumexp`.
    # At the end of each division we place a 'marker'.
    # We use points on the unit interval supplied by caller.
    # We sort the combination of points and markers. The number
    # of points between the markers defining a division gives the number
    # of samples we require in that division.
    # For example, suppose `probs` is `[0.2, 0.3, 0.5]`.
    # We divide up `[0, 1]` using 3 markers:
    #
    #     |     |          |
    # 0.  0.2   0.5        1.0  <- markers
    #
    # Suppose we are given four points: [0.1, 0.25, 0.9, 0.75]
    # After sorting the combination we get:
    #
    # 0.1  0.25     0.75 0.9    <- points
    #  *  | *   |    *    *|
    # 0.   0.2 0.5         1.0  <- markers
    #
    # We have one sample in the first category, one in the second and
    # two in the last.
    #
    # All of these computations are carried out in batched form.

    with tf.name_scope(name or 'resample_using_log_points') as name:
        points_shape = ps.shape(log_points)
        batch_shape, [num_markers] = ps.split(ps.shape(log_probs),
                                              num_or_size_splits=[-1, 1])

        # `working_shape` specifies the total number of events
        # we will be generating.
        working_shape = ps.concat([sample_shape, batch_shape], axis=0)
        # `markers_shape` is the shape of the markers we temporarily insert.
        markers_shape = ps.concat([working_shape, [num_markers]], axis=0)

        markers = ps.concat([
            tf.ones(markers_shape, dtype=tf.int32),
            tf.zeros(points_shape, dtype=tf.int32)
        ],
                            axis=-1)
        log_marker_positions = tf.broadcast_to(
            log_cumsum_exp(log_probs, axis=-1), markers_shape)
        log_markers_and_points = ps.concat([log_marker_positions, log_points],
                                           axis=-1)
        # Stable sort is used to ensure that no points get sorted between
        # markers that have zero distance between them. This ensures that
        # there will never be a sample drawn whose probability is intended
        # to be zero even when a point falls on the edge of the
        # corresponding zero-width bucket.
        indices = tf.argsort(log_markers_and_points, axis=-1, stable=True)
        sorted_markers = tf.gather_nd(
            markers,
            indices[..., tf.newaxis],
            batch_dims=(ps.rank_from_shape(sample_shape) +
                        ps.rank_from_shape(batch_shape)))
        markers_and_samples = ps.cast(tf.cumsum(sorted_markers, axis=-1),
                                      dtype=tf.int32)
        markers_and_samples = tf.math.minimum(markers_and_samples,
                                              num_markers - np.int32(1))

        # Collect up samples, omitting markers.
        samples_mask = tf.equal(sorted_markers, 0)

        # The following block of code is equivalent to
        # `samples = markers_and_samples[samples_mask]`; however, boolean mask
        # indices are not supported by XLA.
        # Instead we use `argsort` to pick out the top `num_samples`
        # elements of `markers_and_samples` when sorted using `samples_mask`
        # as key.
        num_samples = points_shape[-1]
        sample_locations = tf.argsort(ps.cast(samples_mask, dtype=tf.int32),
                                      direction='DESCENDING',
                                      stable=True)
        samples = tf.gather_nd(markers_and_samples,
                               sample_locations[..., :num_samples, tf.newaxis],
                               batch_dims=(ps.rank_from_shape(sample_shape) +
                                           ps.rank_from_shape(batch_shape)))

        return tf.reshape(samples, points_shape)
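# Minimal numeric sketch of the bucketing idea described in the comments above,
# reproducing the worked example with tf.searchsorted rather than the
# marker/argsort machinery used in the function: each point falls into the
# bucket whose cumulative probability first exceeds it.
import tensorflow as tf

probs = tf.constant([0.2, 0.3, 0.5])
points = tf.constant([0.1, 0.25, 0.9, 0.75])
buckets = tf.searchsorted(tf.cumsum(probs), points, side='right')
# buckets == [0, 1, 2, 2]: one sample from the first category, one from the
# second, and two from the last.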
def _calculate_spline_coeffs(x_data, y_data):
    """Calculates the coefficients for the spline interpolation.

  These are the values of the second derivative of the spline at `x_data`.
  See p.548 of [1].

  Below is an outline of the function when the number of observations is equal to 7.
  The coefficients are obtained by building and solving a tridiagonal linear
  system of equations with symmetric matrix

   w2,  dx2,   0,   0,   0
   dx2,  w3, dx3,   0,   0
   0,  dx3,   w4, dx4,   0
   0,    0,  dx4,  w5, dx5
   0,    0,    0, dx5,  w6

   where:
   wn = 2 * (x_data[n] - x_data[n-2])
   dxn = x_data[n] - x_data[n-1]

   and the right hand side of the equation is:
   [[3 * ((d2-d1)/dx2 - (d1-d0)/dx1)],
    [3 * ((d3-d2)/dx3 - (d2-d1)/dx2)],
    ...
   ]

   with di = y_data[..., i]

   Solve for `spline_coeffs`, so that  matrix * spline_coeffs = rhs

   the solution is the `spline_coeffs` parameter of the spline equation:

   y_pred = a(spline_coeffs) * t^3 + b(spline_coeffs) * t^2
            + c(spline_coeffs) * t + d(spline_coeffs)
   with t being the fractional distance of the x value between the current
   knot and the next knot:

   t = (x_values - x_data[:,n]) / (x_data[:,n+1]-x_data[:,n])

   and `a`, `b`, `c`, and `d` are functions of `spline_coeffs` and `x_data` and
   are provided in the `interpolate` function.

  ## References:
  [1]: R. Sedgewick, Algorithms in C, 1990, p. 545-550.
    Link: http://index-of.co.uk/Algorithms/Algorithms%20in%20C.pdf

  Args:
    x_data: A real `Tensor` of shape `[..., num_points]` containing
      X-coordinates of points to fit the splines to. The values have to
      be monotonically non-decreasing along the last dimension.
    y_data: A `Tensor` of the same shape and `dtype` as `x_data` containing
      Y-coordinates of points to fit the splines to.

  Returns:
     A `Tensor` of the same shape and `dtype` as `x_data`. Represents the
     spline coefficients for the cubic spline interpolation.
  """

    # `dx` holds the distances between consecutive x points. It is one element
    # shorter than `x_data`.
    dx = x_data[..., 1:] - x_data[..., :-1]

    # `diag_values` holds the diagonal values 2 * (x_data[i+1] - x_data[i-1]);
    # it is two elements shorter than `x_data`.
    diag_values = 2.0 * (x_data[..., 2:] - x_data[..., :-2])
    superdiag = dx[..., 1:]
    subdiag = dx[..., :-1]

    corr_term = tf.logical_or(tf.equal(superdiag, 0), tf.equal(subdiag, 0))
    diag_values_corr = tf.where(corr_term, tf.ones_like(diag_values),
                                diag_values)
    superdiag_corr = tf.where(tf.equal(subdiag, 0), tf.zeros_like(superdiag),
                              superdiag)
    subdiag_corr = tf.where(tf.equal(superdiag, 0), tf.zeros_like(subdiag),
                            subdiag)
    diagonals = tf.stack([superdiag_corr, diag_values_corr, subdiag_corr],
                         axis=-2)

    # determine the rhs of the equation
    dd = (y_data[..., 1:] - y_data[..., :-1]) / dx
    dd = tf.where(tf.equal(dx, 0), tf.zeros_like(dd), dd)
    # rhs is a column vector:
    # [[-3 * ((y1-y0)/dx0 - (y2-y1)/dx1)], ...]
    rhs = -3 * (dd[..., :-1] - dd[..., 1:])
    rhs = tf.where(corr_term, tf.zeros_like(rhs), rhs)
    # Partial pivoting is unnecessary since the matrix is diagonally dominant.
    spline_coeffs = tf.linalg.tridiagonal_solve(diagonals,
                                                rhs,
                                                partial_pivoting=False)
    # Reshape `spline_coeffs`
    zero = tf.zeros_like(dx[..., :1], dtype=x_data.dtype)
    spline_coeffs = tf.concat([zero, spline_coeffs, zero], axis=-1)
    return spline_coeffs
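# Minimal usage sketch with illustrative data: a single batch of four evenly
# spaced knots. Natural boundary conditions make the first and last
# coefficients (the second derivatives at the end knots) zero.
import tensorflow as tf

x_data = tf.constant([[0., 1., 2., 3.]], dtype=tf.float64)
y_data = tf.constant([[0., 1., 0., 1.]], dtype=tf.float64)
coeffs = _calculate_spline_coeffs(x_data, y_data)
# coeffs has shape [1, 4] with coeffs[..., 0] == coeffs[..., -1] == 0.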
Example #27
def softquantiles(x,
                  quantiles,
                  quantile_width=None,
                  axis=-1,
                  may_squeeze=True,
                  **kwargs):
    """Computes soft quantiles via optimal transport.

  This operator takes advantage of the fact that an exhaustive softsort is not
  required to recover a single quantile. Instead, one can transport all
  input values in x onto only 3 weighted values. Target weights are adjusted so
  that those values in x that are transported to the middle value in the target
  vector y correspond to those concentrating around the quantile of interest.

  This idea generalizes to more quantiles, interleaving small weights on the
  quantile indices and bigger weights in between, corresponding to the gap from
  one desired quantile to the next one.

  Args:
   x: Tensor<float> of any shape.
   quantiles: list<float> the quantiles to be returned. It can also be a single
     float.
   quantile_width: (float) mass given to the bucket that is supposed to attract
     points whose values concentrate around the desired quantile value. Bigger
     width means that we allow the soft quantile to be a mixture of more points
     further away from the quantile. If None, the width is set at 1/n where n is
     the number of values considered (the size along the 'axis').
   axis: (int) the axis along which to compute the quantile.
   may_squeeze: (bool) should we squeeze the output tensor in case of a single
     quantile.
   **kwargs: see SoftQuantilizer for possible extra parameters.

  Returns:
    A Tensor<float> similar to the input tensor, but the axis dimension is
    replaced by the number of quantiles specified in the quantiles list.
    Hence, if only one quantile is requested (quantiles is a float), only one value
    in that axis is returned. When several quantiles are requested, the tensor
    will have that many values in that axis.

  Raises:
    tf.errors.InvalidArgumentError when the quantiles and quantile width are not
    correct, namely quantiles are either not in sorted order or the
    quantile_width is too large.
  """
    if isinstance(quantiles, float):
        quantiles = [quantiles]
    quantiles = tf.constant(quantiles, tf.float32)

    # Preprocesses submitted quantiles to check that they satisfy elementary
    # constraints.
    valid_quantiles = tf.boolean_mask(
        quantiles, tf.logical_and(quantiles > 0.0, quantiles < 1.0))
    num_quantiles = tf.shape(valid_quantiles)[0]

    # Includes values on both ends of [0,1].
    extended_quantiles = tf.concat([[0.0], valid_quantiles, [1.0]], axis=0)

    # Builds filler_weights in between the target quantiles.
    filler_weights = extended_quantiles[1:] - extended_quantiles[:-1]
    if quantile_width is None:
        quantile_width = tf.reduce_min(
            tf.concat([
                filler_weights,
                [1.0 / tf.cast(tf.shape(x)[axis], dtype=x.dtype)]
            ],
                      axis=0))

    # Takes into account quantile_width in the definition of weights
    shift = -tf.ones(tf.shape(filler_weights), dtype=x.dtype)
    shift = shift + 0.5 * (tf.one_hot(0, num_quantiles + 1) +
                           tf.one_hot(num_quantiles, num_quantiles + 1))
    filler_weights = filler_weights + quantile_width * shift

    assert_op = tf.Assert(tf.reduce_all(filler_weights >= 0.0),
                          [filler_weights])
    with tf.control_dependencies([assert_op]):
        # Adds one more value to have tensors of the same shape to interleave them.
        quantile_weights = tf.ones(num_quantiles + 1) * quantile_width

        # Interleaves the filler_weights with the quantile weights.
        weights = tf.reshape(
            tf.stack([filler_weights, quantile_weights], axis=1), (-1, ))[:-1]

        # Sends only the positive weights to the softsort operator.
        positive_weights = tf.boolean_mask(weights, weights > 0.0)
        all_quantiles = softsort(x,
                                 direction='ASCENDING',
                                 axis=axis,
                                 target_weights=positive_weights,
                                 **kwargs)

        # Recovers the indices corresponding to the desired quantiles.
        odds = tf.math.floormod(tf.range(weights.shape[0], dtype=tf.float32),
                                2)
        positives = tf.cast(weights > 0.0, tf.float32)
        indices = tf.cast(tf.math.cumsum(positives) * odds, dtype=tf.int32)
        indices = tf.boolean_mask(indices, indices > 0) - 1
        result = tf.gather(all_quantiles, indices, axis=axis)

        # In the specific case where we want a single quantile, squeezes the
        # quantile dimension.
        can_squeeze = tf.equal(tf.shape(result)[axis], 1)
        if tf.math.logical_and(can_squeeze, may_squeeze):
            result = tf.squeeze(result, axis=axis)
        return result
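# Minimal usage sketch, assuming the `softsort` operator this function depends
# on is importable from the same module: approximate the median (0.5 quantile)
# of each row along the last axis.
import tensorflow as tf

x = tf.random.uniform((3, 100))
soft_median = softquantiles(x, 0.5, axis=-1)
# soft_median has shape [3]; with may_squeeze=True the size-1 quantile axis is
# squeezed away.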
Example #28
 def _diag(v, k):
   return utils.cond(
       tf.equal(tf.size(v), 0),
       lambda: tf.zeros([abs(k), abs(k)], dtype=v.dtype),
       lambda: tf.linalg.diag(v, k=k))
Example #29
def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None,
                                      multi_label=False,
                                      label_weights=None,
                                      thresholds_distributed_evenly=False):
    """Returns op to update the given confusion matrix variables.

  For every pair of values in y_true and y_pred:

  true_positive: y_true == True and y_pred > thresholds
  false_negatives: y_true == True and y_pred <= thresholds
  true_negatives: y_true == False and y_pred <= thresholds
  false_positive: y_true == False and y_pred > thresholds

  The results will be weighted and added together. When multiple thresholds are
  provided, we will repeat the same for every threshold.

  For estimation of these metrics over a stream of data, the function creates an
  `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
      and corresponding variables to update as values.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A float value, float tensor, python list, or tuple of float
      thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
    top_k: Optional int, indicates that the positive labels should be limited to
      the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `y_true` dimension).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or flattened
      into a single label. When True, the values of `variables_to_update` must
      have a second dimension equal to the number of labels in y_true and
      y_pred, and those tensors must not be RaggedTensors.
    label_weights: (optional) tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN without
      explicit multilabel handling (i.e. when the data is to be flattened).
    thresholds_distributed_evenly: Boolean, whether the thresholds are evenly
      distributed within the list. An optimized method will be used if this is
      the case. See _update_confusion_matrix_variables_optimized() for more
      details.

  Returns:
    Update op.

  Raises:
    ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
      `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if
      `variables_to_update` contains invalid keys.
  """
    if multi_label and label_weights is not None:
        raise ValueError(
            '`label_weights` for multilabel data should be handled '
            'outside of `update_confusion_matrix_variables` when '
            '`multi_label` is True.')
    if variables_to_update is None:
        return
    if not any(key
               for key in variables_to_update if key in list(ConfusionMatrix)):
        raise ValueError(
            'Please provide at least one valid confusion matrix '
            'variable to update. Valid variable key options are: '
            f'"{list(ConfusionMatrix)}". Received: "{variables_to_update.keys()}"'
        )

    variable_dtype = list(variables_to_update.values())[0].dtype

    y_true = tf.cast(y_true, dtype=variable_dtype)
    y_pred = tf.cast(y_pred, dtype=variable_dtype)

    if thresholds_distributed_evenly:
        # Check whether the thresholds have any leading or trailing epsilon added
        # for floating point imprecision. The leading and trailing thresholds will
        # be handled a bit differently as corner cases.
        # At this point, thresholds should be a list/array with more than 2 items,
        # and ranged between [0, 1]. See is_evenly_distributed_thresholds() for more
        # details.
        thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0

    thresholds = tf.convert_to_tensor(thresholds, dtype=variable_dtype)
    num_thresholds = thresholds.shape.as_list()[0]

    if multi_label:
        one_thresh = tf.equal(tf.cast(1, dtype=tf.int32),
                              tf.rank(thresholds),
                              name='one_set_of_thresholds_cond')
    else:
        [y_pred, y_true
         ], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
                                                             sample_weight)
        one_thresh = tf.cast(True, dtype=tf.bool)

    invalid_keys = [
        key for key in variables_to_update if key not in list(ConfusionMatrix)
    ]
    if invalid_keys:
        raise ValueError(
            f'Invalid keys: "{invalid_keys}". '
            f'Valid variable key options are: "{list(ConfusionMatrix)}"')

    if sample_weight is None:
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true)
    else:
        sample_weight = tf.cast(sample_weight, dtype=variable_dtype)
        y_pred, y_true, sample_weight = (
            losses_utils.squeeze_or_expand_dimensions(
                y_pred, y_true, sample_weight=sample_weight))
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    if top_k is not None:
        y_pred = _filter_top_k(y_pred, top_k)
    if class_id is not None:
        y_true = y_true[..., class_id]
        y_pred = y_pred[..., class_id]

    if thresholds_distributed_evenly:
        return _update_confusion_matrix_variables_optimized(
            variables_to_update,
            y_true,
            y_pred,
            thresholds,
            multi_label=multi_label,
            sample_weights=sample_weight,
            label_weights=label_weights,
            thresholds_with_epsilon=thresholds_with_epsilon)

    pred_shape = tf.shape(y_pred)
    num_predictions = pred_shape[0]
    if y_pred.shape.ndims == 1:
        num_labels = 1
    else:
        num_labels = tf.math.reduce_prod(pred_shape[1:], axis=0)
    thresh_label_tile = tf.where(one_thresh, num_labels,
                                 tf.ones([], dtype=tf.int32))

    # Reshape predictions and labels, adding a dim for thresholding.
    if multi_label:
        predictions_extra_dim = tf.expand_dims(y_pred, 0)
        labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype=tf.bool), 0)
    else:
        # Flatten predictions and labels when not multilabel.
        predictions_extra_dim = tf.reshape(y_pred, [1, -1])
        labels_extra_dim = tf.reshape(tf.cast(y_true, dtype=tf.bool), [1, -1])

    # Tile the thresholds for every prediction.
    if multi_label:
        thresh_pretile_shape = [num_thresholds, 1, -1]
        thresh_tiles = [1, num_predictions, thresh_label_tile]
        data_tiles = [num_thresholds, 1, 1]
    else:
        thresh_pretile_shape = [num_thresholds, -1]
        thresh_tiles = [1, num_predictions * num_labels]
        data_tiles = [num_thresholds, 1]

    thresh_tiled = tf.tile(tf.reshape(thresholds, thresh_pretile_shape),
                           tf.stack(thresh_tiles))

    # Tile the predictions for every threshold.
    preds_tiled = tf.tile(predictions_extra_dim, data_tiles)

    # Compare predictions and threshold.
    pred_is_pos = tf.greater(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds
    label_is_pos = tf.tile(labels_extra_dim, data_tiles)

    if sample_weight is not None:
        sample_weight = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weight, dtype=variable_dtype), y_pred)
        weights_tiled = tf.tile(tf.reshape(sample_weight, thresh_tiles),
                                data_tiles)
    else:
        weights_tiled = None

    if label_weights is not None and not multi_label:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred)
        label_weights_tiled = tf.tile(tf.reshape(label_weights, thresh_tiles),
                                      data_tiles)
        if weights_tiled is None:
            weights_tiled = label_weights_tiled
        else:
            weights_tiled = tf.multiply(weights_tiled, label_weights_tiled)

    update_ops = []

    def weighted_assign_add(label, pred, weights, var):
        label_and_pred = tf.cast(tf.logical_and(label, pred), dtype=var.dtype)
        if weights is not None:
            label_and_pred *= tf.cast(weights, dtype=var.dtype)
        return var.assign_add(tf.reduce_sum(label_and_pred, 1))

    loop_vars = {
        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
    }
    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

    if update_fn or update_tn:
        pred_is_neg = tf.logical_not(pred_is_pos)
        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos,
                                                      pred_is_neg)

    if update_fp or update_tn:
        label_is_neg = tf.logical_not(label_is_pos)
        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg,
                                                      pred_is_pos)
        if update_tn:
            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg,
                                                         pred_is_neg)

    for matrix_cond, (label, pred) in loop_vars.items():

        if matrix_cond in variables_to_update:
            update_ops.append(
                weighted_assign_add(label, pred, weights_tiled,
                                    variables_to_update[matrix_cond]))

    return tf.group(update_ops)
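# Minimal sketch of the core comparison performed above, for a single threshold
# and unweighted samples (illustrative values); the function tiles the same
# comparison across every threshold and accumulates into the given variables.
import tensorflow as tf

y_true = tf.constant([True, True, False, False])
y_pred = tf.constant([0.9, 0.3, 0.8, 0.1])
threshold = 0.5

pred_is_pos = tf.greater(y_pred, threshold)
tp = tf.reduce_sum(tf.cast(tf.logical_and(y_true, pred_is_pos), tf.float32))
fp = tf.reduce_sum(tf.cast(tf.logical_and(~y_true, pred_is_pos), tf.float32))
fn = tf.reduce_sum(tf.cast(tf.logical_and(y_true, ~pred_is_pos), tf.float32))
# tp == 1, fp == 1, fn == 1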
def main(argv):
    if FLAGS.data_dir:
        if 'subsampled-tiny' in FLAGS.data_dir:
            n_data = 50000 // 16
        elif 'subsampled' in FLAGS.data_dir:
            n_data = 50000 // 4
        else:
            # Directory name does not indicate subsampling; use the full set.
            n_data = 50000
    else:
        n_data = 50000
    steps_per_epoch = n_data // FLAGS.batch
    optimizer = tf.keras.optimizers.SGD(FLAGS.lr, momentum=0.9)

    trained_model = tf.keras.models.load_model(FLAGS.trained_model)
    layer_output = trained_model.layers[FLAGS.layer_idx].output
    out_dim = np.array(
        layer_output.get_shape().as_list()[1:])  # remove batch dimension
    if FLAGS.pooling:
        out_dim[0] /= 2
        out_dim[1] /= 2
    total_dim = np.prod(out_dim)
    train_dataset = load_linear_probe_train_data(trained_model,
                                                 FLAGS.layer_idx,
                                                 (total_dim, ),
                                                 FLAGS.batch,
                                                 data_path=FLAGS.data_sample)
    test_dataset = load_linear_probe_test_data(
        trained_model, FLAGS.layer_idx, (total_dim, ),
        FLAGS.batch)  #use full test dataset as validation

    #Define linear model
    inputs = tf.keras.Input(shape=(total_dim, ))
    outputs = tf.keras.layers.Dense(
        10,
        kernel_initializer='he_normal',
        kernel_regularizer=tf.keras.regularizers.l2(FLAGS.l2_reg))(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer,
        tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['acc'])
    save_dir = 'layer-%d-bs-%d-lr-%f-reg-%f' % \
        (FLAGS.layer_idx, FLAGS.batch, FLAGS.lr, FLAGS.l2_reg)
    if FLAGS.pooling:
        save_dir += '-pooling'
    experiment_dir = os.path.join(FLAGS.trained_model, save_dir)

    # Resume training in case of preemption
    optimizer_weights_set = True
    ckpt_path = os.path.join(experiment_dir, 'ckpt')
    opt_path = os.path.join(ckpt_path, 'optimizer_weights.pkl')
    metadata_path = os.path.join(ckpt_path, 'metadata.pkl')

    if not tf.io.gfile.exists(ckpt_path):
        tf.io.gfile.makedirs(ckpt_path)
    if tf.io.gfile.listdir(ckpt_path):
        opt_weights = pickle.load(tf.io.gfile.GFile(opt_path, 'rb'))
        optimizer_weights_set = False
        #optimizer.set_weights(opt_weights)
        model = tf.keras.models.load_model(ckpt_path)

    if tf.io.gfile.exists(metadata_path):
        metadata = pickle.load(tf.io.gfile.GFile(metadata_path, 'rb'))
        start_epoch = metadata['latest_epoch'] + 1
        best_val_acc = metadata['best_acc']
    else:
        best_val_acc = 0
        start_epoch = 0

    # Start training
    for epoch in range(start_epoch, FLAGS.num_epochs):
        for (batch_id, (images, labels)) in enumerate(train_dataset.take(-1)):
            if batch_id >= steps_per_epoch:
                continue

            with tf.GradientTape(persistent=True) as tape:
                logits = model(images, training=True)
                loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                    from_logits=True)(labels, logits)

            grads = tape.gradient(loss_fn, model.trainable_variables)

            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            if not optimizer_weights_set:  # optimizer weights are only created during the first step
                optimizer.set_weights(opt_weights)
                optimizer_weights_set = True

        #Evaluate the model and print results
        n_correct_preds, n_val = 0, 10000
        for (_, (images, labels)) in enumerate(test_dataset.take(-1)):
            logits = model(images, training=False)
            correct_preds = tf.equal(tf.argmax(input=logits, axis=1), labels)
            n_correct_preds += correct_preds.numpy().sum()
        val_accuracy = n_correct_preds / n_val

        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            tf.keras.models.save_model(model,
                                       experiment_dir,
                                       overwrite=True,
                                       include_optimizer=False)

        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            tf.keras.models.save_model(model,
                                       ckpt_path,
                                       overwrite=True,
                                       include_optimizer=False)
            metadata = {'latest_epoch': epoch, 'best_acc': best_val_acc}
            pickle.dump(metadata, tf.io.gfile.GFile(metadata_path, 'wb'))
            opt_weights = optimizer.get_weights()
            pickle.dump(opt_weights, tf.io.gfile.GFile(opt_path, 'wb'))
Example #31
 def f(a1, a2):
     if a1.shape != a2.shape:
         return tf.constant(False)
     return tf.reduce_all(tf.equal(a1, a2))