Example #1
0
  def call(self, x, training=False):
    """Quantizes `x` against an EMA-maintained, multi-head codebook.

    `x` is reshaped to rows of length `self.depth`, and each row is split into
    `self.num_heads` equal segments (so `self.depth` must be divisible by
    `self.num_heads`). Every segment is matched to the nearest of the `self.k`
    centroids, which are derived from running sums (`self.sums`) and counts
    (`self.counts`).

    Args:
      x: Input tensor. It is reshaped to `(-1, self.depth)`, so its element
        count must be a multiple of `self.depth` (presumably its last
        dimension is `self.depth` -- TODO confirm against callers).
      training: If True, restart under-used centroids from batch elements and
        update the EMA statistics `self.counts` / `self.sums` in place.

    Returns:
      A tuple `(z, c)`:
        z: Same shape as `x`; the quantized vectors with a straight-through
          gradient (forward value is the centroid, gradient flows to `x`).
        c: The centroid index chosen per head, with shape
          `tf.shape(x)[:-1] + [self.num_heads]`.
    """
    x_flat = tf.reshape(x, shape=(-1, self.depth))

    # Split each input vector into one segment per head and stack the segments
    # along the row axis so all heads share a single codebook lookup.
    x_flat_split = tf.split(x_flat, self.num_heads, axis=1)
    x_flat = tf.concat(x_flat_split, axis=0)

    if training:
      # Take the row count from the dynamic shape: the static
      # `x_flat.shape[0]` is None whenever the batch dimension is unknown
      # (e.g. when traced inside a tf.function), which would make the
      # arithmetic below fail.
      n_rows = tf.shape(x_flat)[0]
      n = tf.cast(n_rows, self.counts.dtype)

      # Figure out which centroids we want to keep, and which we want to
      # restart: a centroid is kept while its EMA count exceeds
      # `restart_threshold` times its even share (n / k) of the rows.
      keep = self.counts * self.k > self.restart_threshold * n
      restart = tf.math.logical_not(keep)

      # Replace centroids to restart with elements from the batch, using samples
      # from a uniform distribution as a fallback in case we need to restart
      # more centroids than we have elements in the batch.
      restart_idx = tf.squeeze(tf.where(restart), -1)
      n_replace = tf.minimum(tf.shape(restart_idx)[0], n_rows)
      e_restart = tf.tensor_scatter_nd_update(
          tf.random.uniform([self.k, self.depth // self.num_heads]),
          tf.expand_dims(restart_idx[:n_replace], 1),
          tf.random.shuffle(x_flat)[:n_replace]
      )

      # Compute the values of the centroids we want to keep by dividing the
      # summed vectors by the corresponding counts.
      e = tf.where(
          tf.expand_dims(keep, 1),
          tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1)),
          e_restart
      )

    else:
      # If not training, just use the centroids as is with no restarts.
      e = tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1))

    # Squared Euclidean distance between each row and each centroid, expanded
    # as |x|^2 - 2 x.e + |e|^2 so it is computed with a single matmul.
    distances = (
        tf.expand_dims(tf.reduce_sum(x_flat**2, axis=1), 1) -
        2 * tf.matmul(x_flat, tf.transpose(e)) +
        tf.expand_dims(tf.reduce_sum(e**2, axis=1), 0)
    )

    # Find nearest cluster center for each input vector.
    c = tf.argmin(distances, axis=1)

    # Quantize input vectors with straight-through estimator: undo the
    # per-head row stacking, restore x's shape, and route gradients to x.
    z = tf.nn.embedding_lookup(e, c)
    z_split = tf.split(z, self.num_heads, axis=0)
    z = tf.concat(z_split, axis=1)
    z = tf.reshape(z, tf.shape(x))
    z = x + tf.stop_gradient(z - x)

    if training:
      # Compute cluster counts and vector sums over the batch.
      oh = tf.one_hot(indices=c, depth=self.k)
      counts = tf.reduce_sum(oh, axis=0)
      sums = tf.matmul(oh, x_flat, transpose_a=True)

      # Apply exponential moving average to cluster counts and vector sums:
      # stat <- gamma * stat + (1 - gamma) * batch_stat.
      self.counts.assign_sub((1 - self.gamma) * (self.counts - counts))
      self.sums.assign_sub((1 - self.gamma) * (self.sums - sums))

    # Regroup the per-row indices into one column per head.
    c_split = tf.split(c, self.num_heads, axis=0)
    c = tf.stack(c_split, axis=1)
    c = tf.reshape(c, tf.concat([tf.shape(x)[:-1], [self.num_heads]], axis=0))

    return z, c
 def _mode(self):
   """Returns a one-hot tensor marking the category with the largest logit."""
   logits = self._logits_parameter_no_checks()
   most_likely = tf.argmax(logits, axis=-1)
   num_categories = self._event_size(logits)
   one_hot_mode = tf.one_hot(most_likely, num_categories, dtype=self.dtype)
   # Attach the static shape of `logits` to the result.
   tensorshape_util.set_shape(one_hot_mode, logits.shape)
   return one_hot_mode
Example #3
0
def helper_test_keras_v2_gradienttape(script_mode: bool = False,
                                      json_file_contents="{}"):
    """Train with a GradientTape loop and check that smdebug saved losses.

    Args:
      script_mode: If True, build a KerasHook explicitly, wrap every tape
        with it and record accuracy through it. If False, rely on
        zero-code-change (ZCC) hook creation with an unwrapped tape.
      json_file_contents: JSON passed to SagemakerSimulator. In script mode,
        anything other than "{}" makes the hook come from that JSON file.
    """
    smd.del_hook()
    tf.keras.backend.clear_session()

    with SagemakerSimulator(json_file_contents=json_file_contents) as sim:
        model = tf.keras.models.Sequential([
            # WA for TF issue #36279
            tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation="softmax"),
        ])
        (x_train, y_train), _ = get_keras_data()
        dataset = tf.data.Dataset.from_tensor_slices(
            (tf.cast(x_train[..., tf.newaxis] / 255,
                     tf.float32), tf.cast(y_train, tf.int64)))
        dataset = dataset.shuffle(1000).batch(64)

        opt = tf.keras.optimizers.RMSprop()
        # The final Dense layer already applies softmax, so the loss receives
        # probabilities; `from_logits=True` would apply softmax a second time.
        cce = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
        train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        n_epochs = 2

        if script_mode:
            if json_file_contents == "{}":
                script_hook = smd.KerasHook(out_dir=sim.out_dir,
                                            export_tensorboard=True)
            else:
                script_hook = smd.KerasHook.create_from_json_file()

            def make_tape():
                return script_hook.wrap_tape(tf.GradientTape())
        else:
            # ZCC support added from smdebug v0.8.0: no hook is created here
            # and the tape is left unwrapped.
            script_hook = None

            def make_tape():
                return tf.GradientTape(persistent=True)

        for epoch in range(n_epochs):
            print("Epoch %d/%d" % (epoch + 1, n_epochs))
            for data, labels in dataset:
                dataset_labels = labels
                labels = tf.one_hot(labels, depth=10)
                with make_tape() as tape:
                    probs = model(data, training=True)  # (64, 10)
                    loss_value = cce(labels, probs)
                grads = tape.gradient(loss_value, model.variables)
                opt.apply_gradients(zip(grads, model.variables))
                acc = train_acc_metric(dataset_labels, probs)
                if script_hook is not None:
                    script_hook.record_tensor_value(tensor_name="accuracy",
                                                    tensor_value=acc)
            log = "Epoch %d " % (epoch + 1)
            log += "Accuracy %.4f" % train_acc_metric.result()
            print(log)
            train_acc_metric.reset_states()

        hook = smd.get_hook()
        if not script_mode and not is_tf_2_2():
            assert not hook  # only supported on TF 2.2 and greater
            return
        assert hook
        hook.close()
        # Check that hook created and tensors saved
        trial = smd.create_trial(path=sim.out_dir)
        assert len(trial.steps()) > 0, "Nothing saved at any step."
        assert len(trial.tensor_names()) > 0, "Tensors were not saved."
        assert len(trial.tensor_names(collection="losses")) > 0
    def _sample_n(self, n, seed=None):
        """Draws `n` observation sequences from the hidden Markov model.

        Args:
          n: Number of samples to draw.
          seed: Optional seed. Every random draw (initial state, transitions
            and observations) takes a seed from a stream derived from it, and
            when it is given `tf.scan` runs one step at a time so that seeded
            sampling is reproducible.

        Returns:
          Sampled observations laid out as
          `[n] + batch_shape + [num_steps] + event_shape`.
        """
        with tf.control_dependencies(self._runtime_assertions):
            # Kept under a separate name so `seed` can still be checked below.
            strm = seed_stream.SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(input_tensor=batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            init_repeat = (
                batch_size //
                tf.reduce_prod(input_tensor=self._initial_distribution.
                               batch_shape_tensor()))
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=strm())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size
            init_shape = init_state.shape

            transition_repeat = (
                batch_size //
                tf.reduce_prod(input_tensor=self._transition_distribution.
                               batch_shape_tensor()[:-1]))

            def generate_step(state, _):
                """Take a single step in Markov chain."""
                gen = self._transition_distribution.sample(
                    n * transition_repeat, seed=strm())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])
                # new_states :: n batch_size num_states

                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)
                # old_states :: n batch_size num_states

                result = tf.reduce_sum(
                    input_tensor=old_states_one_hot * new_states, axis=-1)
                # The transition matrix is square, so each step keeps the
                # shape of the state tensor. TensorFlow may not infer that
                # through the reshapes, and tf.scan needs every step to match
                # the initializer's shape, so state it explicitly.
                result.set_shape(init_shape)
                return result

            if self._num_steps > 1:
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                if seed is not None:
                    # Force parallel_iterations to 1 to ensure reproducibility
                    # of seeded sampling (b/139210489).
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state,
                                            parallel_iterations=1)
                else:
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state)

                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                hidden_states = tf.concat([[init_state], hidden_states],
                                          axis=0)
            else:
                hidden_states = init_state[tf.newaxis, ...]

            # hidden_states :: num_steps n batch_size

            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size so as with the initial and
            # transition distributions we generate more samples and
            # reshape.
            observation_repeat = (
                batch_size //
                tf.reduce_prod(input_tensor=self._observation_distribution.
                               batch_shape_tensor()[:-1]))

            # Seeded from the same stream as the other draws so a given
            # `seed` reproduces the whole sample.
            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=strm())

            # Dynamic event shape: the static `event_shape` may not be fully
            # defined, in which case it cannot be concatenated below.
            inner_shape = self._observation_distribution.event_shape_tensor()

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            # Select the observation of the active hidden state by masking
            # with the one-hot state and summing over the state axis.
            observations = tf.reduce_sum(input_tensor=hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(input=inner_shape))

            # observations :: steps n batch_size inner_shape

            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(input=batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations
Example #5
0
def helper_keras_gradtape(
    trial_dir,
    save_all=False,
    include_collections=None,
    reduction_config=None,
    save_config=None,
    hook=None,
    batch_size=64,
    persistent=False,
):
    """Train a small MNIST classifier for one epoch under a hook-wrapped tape.

    Args:
      trial_dir: Output directory for the KerasHook built here when `hook`
        is None.
      save_all: Forwarded to the KerasHook built here.
      include_collections: Forwarded to the KerasHook built here. When
        `save_all` is False, every collection in the hook's
        `include_collections` that is not listed here gets
        `SaveConfig(end_step=0)`.
      reduction_config: Forwarded to the KerasHook built here.
      save_config: Save config for the KerasHook built here; defaults to
        `SaveConfig(save_interval=3)`.
      hook: Pre-built hook to use instead. When given, the hook-configuration
        arguments above are ignored.
      batch_size: Batch size of the MNIST training dataset.
      persistent: If True, use a persistent tape and call `gradient()` a
        second time to exercise repeated gradient computation.
    """
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), _ = mnist.load_data()
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(x_train[..., tf.newaxis] / 255,
                 tf.float32), tf.cast(y_train, tf.int64)))
    dataset = dataset.shuffle(1000).batch(batch_size)

    model = tf.keras.models.Sequential([
        # WA for TF issue https://github.com/tensorflow/tensorflow/issues/36279
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])

    if hook is None:
        if save_config is None:
            save_config = SaveConfig(save_interval=3)

        hook = smd.KerasHook(
            trial_dir,
            save_config=save_config,
            save_all=save_all,
            include_collections=include_collections,
            reduction_config=reduction_config,
        )

        if not save_all and include_collections is not None:
            for cname in hook.include_collections:
                if cname not in include_collections:
                    hook.get_collection(cname).save_config = SaveConfig(
                        end_step=0)

    opt = tf.keras.optimizers.Adam()
    hook.wrap_optimizer(opt)

    # The final Dense layer already applies softmax, so the loss receives
    # probabilities; `from_logits=True` would apply softmax a second time.
    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

    n_epochs = 1
    for _ in range(n_epochs):
        for data, labels in dataset:
            dataset_labels = labels
            labels = tf.one_hot(labels, depth=10)
            with hook.wrap_tape(
                    tf.GradientTape(persistent=persistent)) as tape:
                probs = model(data, training=True)  # (batch_size, 10)
                loss_value = cce(labels, probs)
            grads = tape.gradient(loss_value, model.variables)

            # By default, the resources held by a GradientTape are released as
            # soon as GradientTape.gradient() method is called. To compute
            # multiple gradients over the same computation, create a persistent
            # gradient tape. This allows multiple calls to the gradient() method
            # as resources are released when the tape object is garbage collected.
            if persistent:
                _ = tape.gradient(loss_value, model.variables)
            opt.apply_gradients(zip(grads, model.variables))
            acc = train_acc_metric(dataset_labels, probs)
            hook.record_tensor_value(tensor_name="accuracy", tensor_value=acc)
        train_acc_metric.reset_states()

    hook.close()
Example #6
0
 def _one_hot_encoding_label(wav, label):
   """Returns `wav` unchanged, paired with `label` one-hot encoded.

   `num_classes` comes from the enclosing scope (not visible here).
   """
   encoded_label = tf.one_hot(label, num_classes)
   return wav, encoded_label
    def train_step(self,
                   dataset: dataset_lib.OffpolicyDataset,
                   target_policy: tf_policy.TFPolicy,
                   regularizer: float = 1e-6):
        """Performs a single iteration of CoinDICE.

        Args:
          dataset: The dataset to sample experience from. NOTE(review): not
            read anywhere in this method body; the step works from
            precomputed attributes such as `self._a_vec`,
            `self._weighted_rewards` and `self._nu_indices`.
          target_policy: The policy whose value we want to estimate. Likewise
            not read in this method body.
          regularizer: Small constant times the identity added to each Hessian
            before the Newton-step linear solve.

        Returns:
          A list `[avg_saddle_loss * sign, avg_loss * sign, divergence]`,
          where `sign` is `self._algae_alpha_sign` and `divergence` is
          computed from the final weights.
        """
        # First compute Lagrangian loss.
        saddle_bellman_residuals = (tf.matmul(self._a_vec, self._nu) -
                                    self._weighted_rewards[:, None])
        saddle_bellman_residuals *= -1 * self._algae_alpha_sign
        saddle_zetas = tf.gather(self._zeta, self._nu_indices)
        saddle_initial_nu_values = tf.reduce_sum(  # Average over actions.
            self._initial_target_probs[:, :, None] *
            tf.gather(self._nu, self._initial_nu_indices),
            axis=1)
        saddle_init_nu_loss = ((1 - self._gamma) * saddle_initial_nu_values *
                               self._algae_alpha_sign)

        # This second optimization switches the sign of algae_alpha.
        # We add these two together to get the final loss, and thus counteract
        # the bias introduced by algae_alpha.
        saddle_bellman_residuals2 = (tf.matmul(self._a_vec, self._nu2) -
                                     self._weighted_rewards[:, None])
        saddle_bellman_residuals2 *= 1 * self._algae_alpha_sign
        saddle_zetas2 = tf.gather(self._zeta2, self._nu_indices)
        saddle_initial_nu_values2 = tf.reduce_sum(  # Average over actions.
            self._initial_target_probs[:, :, None] *
            tf.gather(self._nu2, self._initial_nu_indices),
            axis=1)
        saddle_init_nu_loss2 = ((1 - self._gamma) * saddle_initial_nu_values2 *
                                -1 * self._algae_alpha_sign)

        saddle_loss = 0.5 * (
            saddle_init_nu_loss + saddle_bellman_residuals * saddle_zetas +
            -tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas) +
            -saddle_init_nu_loss2 + -saddle_bellman_residuals2 * saddle_zetas2
            + tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas2))

        # Find optimal weights by doing binary search on alpha (lambda in the
        # paper).
        # 16 bisection steps, elementwise over the two entries of the bounds:
        # where the divergence exceeds `self._two_sided_limit` the lower bound
        # moves up to `mid`, otherwise the upper bound moves down to `mid`.
        # NOTE(review): the bounds have length 2 while later code loops over
        # `range(self._num_limits)`; presumably `_num_limits == 2` -- confirm.
        left = tf.constant([-8., -8.])
        right = tf.constant([32., 32.])
        for _ in range(16):
            mid = 0.5 * (left + right)
            self._alpha.assign(mid)
            weights, log_weights = self._get_weights(saddle_loss)

            divergence = self._compute_divergence(weights, log_weights)
            divergence_violation = divergence - self._two_sided_limit
            left = tf.where(divergence_violation > 0., mid, left)
            right = tf.where(divergence_violation > 0., right, mid)
        # Settle alpha at the midpoint of the final bracket and recompute the
        # weights with it.
        self._alpha.assign(0.5 * (left + right))
        weights, log_weights = self._get_weights(saddle_loss)

        # Now that we have weights, we reconstruct the Bellman residual matrices.
        data_weights = tf.stop_gradient(weights)
        avg_saddle_loss = (tf.reduce_sum(data_weights * saddle_loss, axis=0) /
                           tf.reduce_sum(data_weights, axis=0))

        # Per-sample weighted count of the sample's own index in
        # `self._nu_indices`, used below to normalize the TD matrix and bias.
        weighted_state_action_count = tf.reduce_sum(
            tf.one_hot(self._nu_indices, self._dimension)[:, :, None] *
            weights[:, None, :],
            axis=0)
        weighted_state_action_count = tf.gather(weighted_state_action_count,
                                                self._nu_indices)
        my_td_mat = tf.einsum('ai, ab, ab, aj -> bij',
                              tf.one_hot(self._nu_indices, self._dimension),
                              1.0 / weighted_state_action_count, weights,
                              self._a_vec)
        my_bias = tf.reduce_sum(
            tf.transpose(weights)[:, :, None] *
            tf.one_hot(self._nu_indices, self._dimension)[None, :, :] *
            tf.reshape(self._weighted_rewards, [1, -1, 1]) * 1.0 /
            tf.transpose(weighted_state_action_count)[:, :, None],
            axis=1)

        # Solve for nu using primal form; i.e., E[(nu - B nu)^2] - (1-g) * E[nu0].
        # The tape is persistent because it is queried four times: two
        # gradient() calls inside the block and two jacobian() calls after it.
        # The gradients are taken inside the `with` so that the tape records
        # them, which the later jacobian() calls need.
        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch([self._nu, self._nu2, self._alpha])
            bellman_residuals = tf.matmul(
                my_td_mat,
                tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
            bellman_residuals = tf.transpose(tf.squeeze(bellman_residuals, -1))
            bellman_residuals = tf.gather(bellman_residuals, self._nu_indices)
            initial_nu_values = tf.reduce_sum(  # Average over actions.
                self._initial_target_probs[:, :, None] *
                tf.gather(self._nu, self._initial_nu_indices),
                axis=1)

            bellman_residuals *= self._algae_alpha_sign

            init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                            self._algae_alpha_sign)

            nu_loss = (tf.math.square(bellman_residuals) / 2.0 +
                       tf.math.abs(self._algae_alpha) * init_nu_loss)

            loss = (data_weights * nu_loss /
                    tf.reduce_sum(data_weights, axis=0, keepdims=True))

            # Same construction for the second (sign-flipped) copy nu2.
            bellman_residuals2 = tf.matmul(
                my_td_mat,
                tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
            bellman_residuals2 = tf.transpose(
                tf.squeeze(bellman_residuals2, -1))
            bellman_residuals2 = tf.gather(bellman_residuals2,
                                           self._nu_indices)
            initial_nu_values2 = tf.reduce_sum(  # Average over actions.
                self._initial_target_probs[:, :, None] *
                tf.gather(self._nu2, self._initial_nu_indices),
                axis=1)

            bellman_residuals2 *= -1 * self._algae_alpha_sign

            init_nu_loss2 = ((1 - self._gamma) * initial_nu_values2 * -1 *
                             self._algae_alpha_sign)

            nu_loss2 = (tf.math.square(bellman_residuals2) / 2.0 +
                        tf.math.abs(self._algae_alpha) * init_nu_loss2)

            loss2 = (data_weights * nu_loss2 /
                     tf.reduce_sum(data_weights, axis=0, keepdims=True))

            divergence = self._compute_divergence(weights, log_weights)
            divergence_violation = divergence - self._two_sided_limit

            # Extra loss for the 'terminal' state (index = -1).
            extra_loss = tf.reduce_sum(tf.math.square(self._nu[-1, :]))
            extra_loss2 = tf.reduce_sum(tf.math.square(self._nu2[-1, :]))

            nu_grad = tape.gradient(loss + extra_loss, [self._nu])[0]
            nu_grad2 = tape.gradient(loss2 + extra_loss2, [self._nu2])[0]

        avg_loss = tf.reduce_sum(0.5 * (loss - loss2) /
                                 tf.math.abs(self._algae_alpha),
                                 axis=0)
        # For each limit i, keep only the block of the Jacobian pairing column
        # i of the gradient with column i of nu (assumes nu is laid out as
        # (dimension, num_limits) -- TODO confirm).
        nu_jacob = tape.jacobian(nu_grad, [self._nu])[0]
        nu_hess = tf.stack(
            [nu_jacob[:, i, :, i] for i in range(self._num_limits)], axis=0)

        nu_jacob2 = tape.jacobian(nu_grad2, [self._nu2])[0]
        nu_hess2 = tf.stack(
            [nu_jacob2[:, i, :, i] for i in range(self._num_limits)], axis=0)

        for idx, div in enumerate(divergence):
            tf.summary.scalar('divergence%d' % idx, div)

        # Perform Newton step on nu.
        # NOTE(review): `self._nu` / `self._nu2` are rebound to new tensors
        # rather than updated with assign(); if they start out as tf.Variables
        # they stop being Variables after the first call -- confirm intended.
        nu_transformed = tf.transpose(
            tf.squeeze(
                tf.linalg.solve(
                    nu_hess + regularizer * tf.eye(self._dimension),
                    tf.expand_dims(-tf.transpose(nu_grad), axis=-1))))
        self._nu = self._nu + self._nu_learning_rate * nu_transformed
        nu_transformed2 = tf.transpose(
            tf.squeeze(
                tf.linalg.solve(
                    nu_hess2 + regularizer * tf.eye(self._dimension),
                    tf.expand_dims(-tf.transpose(nu_grad2), axis=-1))))
        self._nu2 = self._nu2 + self._nu_learning_rate * nu_transformed2

        # Perform step on zeta based on fact that zeta* = (nu* - bellman nu*)/a.
        # zeta moves a fraction `self._zeta_learning_rate` toward that target.
        zetas = tf.matmul(my_td_mat,
                          tf.transpose(self._nu)[:, :, None]) - my_bias[:, :,
                                                                        None]
        zetas = tf.transpose(tf.squeeze(zetas, -1))
        zetas *= -self._algae_alpha_sign
        zetas /= tf.math.abs(self._algae_alpha)
        self._zeta = self._zeta + self._zeta_learning_rate * (zetas -
                                                              self._zeta)

        zetas2 = tf.matmul(my_td_mat,
                           tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :,
                                                                          None]
        zetas2 = tf.transpose(tf.squeeze(zetas2, -1))
        zetas2 *= 1 * self._algae_alpha_sign
        zetas2 /= tf.math.abs(self._algae_alpha)
        self._zeta2 = (self._zeta2 + self._zeta_learning_rate *
                       (zetas2 - self._zeta2))

        return [
            avg_saddle_loss * self._algae_alpha_sign,
            avg_loss * self._algae_alpha_sign, divergence
        ]
Example #8
0
    def _sample_n(self, n, seed=None):
        """Draws `n` observation sequences from the hidden Markov model.

        Args:
          n: Number of samples to draw.
          seed: Optional seed. All draws below take seeds from a SeedStream
            built from it; when it is not None, `tf.scan` is also forced to
            run one iteration at a time (b/139210489).

        Returns:
          Sampled observations laid out as
          `[n] + batch_shape + [num_steps] + inner_shape` (see the shape
          comments below).
        """
        # Each `strm()` call yields the next seed of the stream, so the order
        # of the sampling calls below decides which seed each draw gets.
        strm = SeedStream(seed, salt='HiddenMarkovModel')

        # NOTE(review): uses the public `transition_distribution` here but
        # `_transition_distribution` in generate_step; presumably the same
        # object -- confirm.
        transition_batch_shape = self.transition_distribution.batch_shape_tensor(
        )
        # The number of hidden states is the last transition batch dimension.
        num_states = transition_batch_shape[-1]

        batch_shape = self.batch_shape_tensor()
        batch_size = tf.reduce_prod(batch_shape)

        # The batch sizes of the underlying initial distributions and
        # transition distributions might not match the batch size of
        # the HMM distribution.
        # As a result we need to ask for more samples from the
        # underlying distributions and then reshape the results into
        # the correct batch size for the HMM.
        init_repeat = (
            tf.reduce_prod(batch_shape) //
            tf.reduce_prod(self._initial_distribution.batch_shape_tensor()))
        init_state = self._initial_distribution.sample(n * init_repeat,
                                                       seed=strm())
        init_state = tf.reshape(init_state, [n, batch_size])
        # init_state :: n batch_size

        transition_repeat = (tf.reduce_prod(batch_shape) //
                             tf.reduce_prod(transition_batch_shape[:-1]))

        init_shape = init_state.shape

        def generate_step(state, _):
            """Take a single step in Markov chain."""

            gen = self._transition_distribution.sample(n * transition_repeat,
                                                       seed=strm())
            # gen :: (n * transition_repeat) transition_batch

            new_states = tf.reshape(gen, [n, batch_size, num_states])

            # new_states :: n batch_size num_states

            old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)

            # old_states :: n batch_size num_states

            # Masking with the one-hot current state and summing over the
            # state axis picks the sampled successor of the current state.
            result = tf.reduce_sum(old_states_one_hot * new_states, axis=-1)
            # We know that `generate_step` must preserve the shape of the
            # tensor of states of each state. This is because
            # the transition matrix must be square. But TensorFlow might
            # not know this so we explicitly tell it that the result has the
            # same shape.
            result.set_shape(init_shape)
            return result

        def _scan_multiple_steps():
            """Take multiple steps with tf.scan."""
            dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
            if seed is not None:
                # Force parallel_iterations to 1 to ensure reproducibility
                # b/139210489
                hidden_states = tf.scan(generate_step,
                                        dummy_index,
                                        initializer=init_state,
                                        parallel_iterations=1)
            else:
                # Invoke default parallel_iterations behavior
                hidden_states = tf.scan(generate_step,
                                        dummy_index,
                                        initializer=init_state)

            # TODO(b/115618503): add/use prepend_initializer to tf.scan
            return tf.concat([[init_state], hidden_states], axis=0)

        # With a single step there is nothing to scan; just add a leading
        # steps axis to the initial state.
        hidden_states = prefer_static.cond(self._num_steps > 1,
                                           _scan_multiple_steps,
                                           lambda: init_state[tf.newaxis, ...])

        hidden_one_hot = tf.one_hot(hidden_states,
                                    num_states,
                                    dtype=self._observation_distribution.dtype)
        # hidden_one_hot :: num_steps n batch_size num_states

        # The observation distribution batch size might not match
        # the required batch size so as with the initial and
        # transition distributions we generate more samples and
        # reshape.
        observation_repeat = (batch_size // tf.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]))

        possible_observations = self._observation_distribution.sample(
            [self._num_steps, observation_repeat * n], seed=strm())

        inner_shape = self._observation_distribution.event_shape_tensor()

        # possible_observations :: num_steps (observation_repeat * n)
        #                          observation_batch[:-1] num_states inner_shape

        possible_observations = tf.reshape(
            possible_observations,
            tf.concat(
                [[self._num_steps, n], batch_shape, [num_states], inner_shape],
                axis=0))

        # possible_observations :: steps n batch_size num_states inner_shape

        hidden_one_hot = tf.reshape(
            hidden_one_hot,
            tf.concat([[self._num_steps, n], batch_shape, [num_states],
                       tf.ones_like(inner_shape)],
                      axis=0))

        # hidden_one_hot :: steps n batch_size num_states "inner_shape"

        # Keep only the observation of the active hidden state by masking with
        # the one-hot state and summing over the num_states axis.
        observations = tf.reduce_sum(hidden_one_hot * possible_observations,
                                     axis=-1 - tf.size(inner_shape))

        # observations :: steps n batch_size inner_shape

        observations = distribution_util.move_dimension(
            observations, 0, 1 + tf.size(batch_shape))

        # returned :: n batch_shape steps inner_shape

        return observations
Example #9
0
 def format_example(imgs, labels):
     """Formats each training and test example to work with our model.

     Images are flattened to rows of 28 * 28 values, cast to float32 and
     divided by 255 (presumably uint8 pixels in [0, 255]). Labels become
     10-way float32 one-hot vectors.
     """
     pixels_per_image = 28 * 28
     flat_imgs = tf.reshape(imgs, [-1, pixels_per_image])
     scaled_imgs = tf.cast(flat_imgs, tf.float32) / 255.0
     one_hot_labels = tf.one_hot(labels, depth=10, dtype=tf.float32)
     return scaled_imgs, one_hot_labels
Example #10
0
    def testShardedState(self):
        """Checks SNAPER HMC adaptation when part of the state is sharded.

        The model has an unsharded `Normal(0, 1)` part and a part sharded over
        `self.axis_name` whose per-device scale is `1 + one_hot(0, NUM_DEVICES)`
        (2 on device 0, 1 elsewhere). After sampling, the test asserts that the
        adapted principal-component and variance estimates reflect that
        structure, and that the step size and mean trajectory length agree
        across the two state parts. Skipped unless running under JAX.
        """

        if not JAX_MODE:
            self.skipTest('b/181800108')

        num_burnin_steps = 1000
        # Step-size / trajectory adaptation runs for the first 80% of burn-in.
        num_adaptation_steps = int(num_burnin_steps * 0.8)
        num_results = 500
        num_chains = 64
        step_size = 1e-2
        num_mala_steps = 100

        def trace_fn(_, pkr):
            # Pull the adapted quantities out of the nested kernel results.
            return {
                'step_size':
                unnest.get_innermost(pkr, 'step_size'),
                'mean_trajectory_length':
                unnest.get_innermost(pkr, 'max_trajectory_length') / 2.,
                'principal_component':
                unnest.get_innermost(pkr, 'ema_principal_component'),
                'variance':
                unnest.get_innermost(pkr, 'ema_variance'),
                'num_leapfrog_steps':
                unnest.get_innermost(pkr, 'num_leapfrog_steps'),
            }

        # The same sharded zero tensor initializes both model parts.
        init_x = ([
            self.shard_values(
                tf.zeros((distribute_test_lib.NUM_DEVICES, num_chains)))
        ] * 2)
        # Per-device scale of the sharded part: 2 on device 0, 1 elsewhere.
        local_scale = self.shard_values(
            1. + tf.one_hot(0, distribute_test_lib.NUM_DEVICES))

        @tf.function(autograph=False)
        def run(init_x, local_scale):
            @tfp.experimental.distribute.JointDistributionCoroutine
            def model():
                yield tfd.Normal(0., 1.)
                yield tfp.experimental.distribute.Sharded(
                    tfd.Normal(0., local_scale),
                    shard_axis_name=self.axis_name)

            kernel = tfp.experimental.mcmc.SNAPERHamiltonianMonteCarlo(
                model.log_prob,
                step_size=step_size,
                num_adaptation_steps=num_adaptation_steps,
                num_mala_steps=num_mala_steps,
                experimental_shard_axis_names=list(
                    model.experimental_shard_axis_names),
            )
            kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
                kernel,
                num_adaptation_steps=num_adaptation_steps,
            )

            # Burn-in is not discarded by `sample_chain`: every step, including
            # the adaptation phase, is traced so the whole history is visible.
            return tfp.mcmc.sample_chain(
                num_results=num_burnin_steps + num_results,
                num_burnin_steps=0,
                current_state=init_x,
                kernel=kernel,
                trace_fn=trace_fn,
                seed=test_util.test_seed(sampler_type='stateless'))

        _, trace = self.evaluate(
            self.per_replica_to_tensor(
                self.strategy_run(
                    run,
                    args=(init_x, local_scale),
                    axis_name=self.axis_name,
                )))

        # NOTE(review): the indexing below assumes `per_replica_to_tensor` adds a
        # leading replica axis, so `[0, -1]` is replica 0 at the last step and
        # `[:, -1]` is every replica at the last step — confirm.
        # Unsharded part: no principal-component weight expected.
        self.assertAllClose(0.,
                            trace['principal_component'][0][0, -1],
                            atol=0.1)
        # Sharded part: the principal component points at device 0's
        # coordinate, the only one with the larger scale.
        expected_local_principal_component = np.zeros(
            distribute_test_lib.NUM_DEVICES)
        expected_local_principal_component[0] = 1.
        self.assertAllClose(expected_local_principal_component,
                            trace['principal_component'][1][:, -1],
                            atol=0.1)

        # Variances: 1 for the unsharded part; scale**2 per device otherwise.
        self.assertAllClose(1., trace['variance'][0][0, -1], atol=0.1)
        expected_local_variance = np.ones(distribute_test_lib.NUM_DEVICES)
        expected_local_variance[0] = 4.
        self.assertAllClose(expected_local_variance,
                            trace['variance'][1][:, -1],
                            rtol=0.2)

        # Shard consistency.
        self.assertAllClose(trace['step_size'][0], trace['step_size'][1])
        self.assertAllClose(trace['mean_trajectory_length'][0],
                            trace['mean_trajectory_length'][1])
Example #11
0
    def _sample_channels(self,
                         component_logits,
                         locs,
                         scales,
                         coeffs=None,
                         seed=None):
        """Sample a single pixel-iteration and apply channel conditioning.

        Args:
          component_logits: 4D `Tensor` of logits for the Categorical
            distribution over Quantized Logistic mixture components.
            Dimensions are `[batch_size, height, width, num_logistic_mix]`.
          locs: 5D `Tensor` of location parameters for the Quantized Logistic
            mixture components. Dimensions are `[batch_size, height, width,
            num_logistic_mix, num_channels]`.
          scales: 5D `Tensor` of scale parameters for the Quantized Logistic
            mixture components. Dimensions are `[batch_size, height, width,
            num_logistic_mix, num_channels]`.
          coeffs: 5D `Tensor` of coefficients for the linear dependence among
            color channels, or `None` if there is only one channel. Dimensions
            are `[batch_size, height, width, num_logistic_mix, num_coeffs]`,
            where `num_coeffs = num_channels * (num_channels - 1) // 2`.
          seed: PRNG seed; see `tfp.random.sanitize_seed` for details. It is
            split into independent sub-seeds: one for the mixture-component
            draw and one per channel for the logistic draws.

        Returns:
          samples: 4D `Tensor` of sampled image data with autoregression among
            channels, each value clipped to `[-1, 1]`. Dimensions are
            `[batch_size, height, width, num_channels]`.
        """
        # Imported locally so this method carries its own seed-splitting dependency.
        from tensorflow_probability.python.internal import samplers  # pylint: disable=g-import-not-at-top

        num_channels = self.event_shape[-1]

        # Give every random draw its own sub-seed. Reusing one stateless seed for
        # the component draw and for each channel's logistic draw would feed the
        # same randomness to every channel and correlate their noise.
        seeds = samplers.split_seed(
            seed, n=num_channels + 1, salt='PixelCNN_sample_channels')
        component_seed = seeds[0]

        # sample mixture components once for the entire pixel
        component_dist = categorical.Categorical(logits=component_logits)
        mask = tf.one_hot(indices=component_dist.sample(seed=component_seed),
                          depth=self._num_logistic_mix)
        # Trailing axis lets the mask broadcast against the channel axis.
        mask = tf.cast(mask[..., tf.newaxis], self.dtype)

        # apply mixture component mask and separate out RGB parameters
        masked_locs = tf.reduce_sum(locs * mask, axis=-2)
        loc_tensors = tf.split(masked_locs, num_channels, axis=-1)
        masked_scales = tf.reduce_sum(scales * mask, axis=-2)
        scale_tensors = tf.split(masked_scales, num_channels, axis=-1)

        if coeffs is not None:
            num_coeffs = num_channels * (num_channels - 1) // 2
            masked_coeffs = tf.reduce_sum(coeffs * mask, axis=-2)
            coef_tensors = tf.split(masked_coeffs, num_coeffs, axis=-1)

        channel_samples = []
        coef_count = 0
        for i in range(num_channels):
            # Channel i's location is shifted by a linear combination of the
            # channels already sampled for this pixel.
            loc = loc_tensors[i]
            for c in channel_samples:
                # Rebind instead of `+=` so no array backend updates a split
                # view of `masked_locs` in place.
                loc = loc + c * coef_tensors[coef_count]
                coef_count += 1

            logistic_samp = logistic.Logistic(
                loc=loc, scale=scale_tensors[i]).sample(seed=seeds[i + 1])
            logistic_samp = tf.clip_by_value(logistic_samp, -1., 1.)
            channel_samples.append(logistic_samp)

        return tf.concat(channel_samples, axis=-1)
    def _sample_n(self, n, seed):
        """Draws `n` samples by sampling all components and masking by the mixture.

        Every component is sampled (`[n, B, k, E]`) and a component index is
        drawn per sample from `mixture_distribution`. The chosen component is
        selected by multiplying with a one-hot mask and summing over the
        component axis. If a sub-distribution rejects the stateless seed with a
        recognized `TypeError`, this falls back, with a warning, to stateful
        sampling driven by a `SeedStream`.

        Args:
          n: number of samples to draw.
          seed: PRNG seed, split via `samplers.split_seed` into component and
            mixture seeds.

        Returns:
          Samples with shape `[n, B, E]` (per the inline shape comments), passed
          through `self._reparameterize_sample` when `self._reparameterize` is set.
        """
        components_seed, mix_seed = samplers.split_seed(
            seed, salt='MixtureSameFamily')
        # Build a stateful SeedStream only as a fallback. Tensor seeds cannot
        # construct one, so keep the error and re-raise it only if the fallback
        # is actually needed.
        try:
            seed_stream = SeedStream(seed, salt='MixtureSameFamily')
        except TypeError as e:  # Can happen for Tensor seeds.
            seed_stream = None
            seed_stream_err = e
        try:
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=components_seed)
            if seed_stream is not None:
                seed_stream()  # Advance even if unused.
        except TypeError as e:
            # Only fall back for errors that indicate the distribution does not
            # accept the stateless seed; anything else propagates.
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `components_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. {}')
            warnings.warn(
                msg.format(self.components_distribution.name,
                           type(self.components_distribution), str(e)))
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=seed_stream())

        # Resolve the event rank, statically when possible, otherwise from the
        # components' dynamic event shape.
        event_shape = None
        event_ndims = tensorshape_util.rank(self.event_shape)
        if event_ndims is None:
            event_shape = self.components_distribution.event_shape_tensor()
            event_ndims = prefer_static.rank_from_shape(event_shape)
        event_ndims_static = tf.get_static_value(event_ndims)

        # The component axis sits just before the event dims of `x`.
        num_components = None
        if event_ndims_static is not None:
            num_components = tf.compat.dimension_value(
                x.shape[-1 - event_ndims_static])
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        if num_components is None:
            num_components = tf.shape(x)[-1 - event_ndims]

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        try:
            mix_sample = self.mixture_distribution.sample(
                n, seed=mix_seed)  # [n, B] or [n]
        except TypeError as e:
            # Same stateless-seed fallback as for the components above.
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `mixture_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(self.mixture_distribution.name,
                           type(self.mixture_distribution), str(e)))
            mix_sample = self.mixture_distribution.sample(
                n, seed=seed_stream())  # [n, B] or [n]
        # One-hot mask in the samples' dtype, selecting the drawn component.
        mask = tf.one_hot(
            indices=mix_sample,  # [n, B] or [n]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k] or [n, k]

        # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
        batch_ndims = prefer_static.rank(x) - event_ndims - 1
        mask_batch_ndims = prefer_static.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = prefer_static.shape(mask)
        mask = tf.reshape(
            mask,
            shape=prefer_static.concat([
                mask_shape[:-1],
                prefer_static.ones([pad_ndims], dtype=tf.int32),
                mask_shape[-1:],
                prefer_static.ones([event_ndims], dtype=tf.int32),
            ],
                                       axis=0))

        # For floating/complex dtypes, `multiply_no_nan` yields 0 where the mask
        # is 0 even if the unselected component sample is NaN/inf, so such
        # values cannot leak into the sum.
        if x.dtype in [
                tf.bfloat16, tf.float16, tf.float32, tf.float64, tf.complex64,
                tf.complex128
        ]:
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            if event_shape is None:
                event_shape = self.components_distribution.event_shape_tensor()
            ret = self._reparameterize_sample(ret, event_shape=event_shape)

        return ret
Example #13
0
 def convert(self, image, label):
     """Scale an image by 1/255 in `self.dtype` and one-hot encode its label.

     Args:
       image: image tensor; cast to `self.dtype` before scaling.
       label: integer class id(s).

     Returns:
       `(image, label)`: the image divided by 255 in `self.dtype`, and the
       label as an int32 one-hot tensor of depth `self.output_size`.
     """
     as_dtype = tf.cast(image, self.dtype)
     scaled_image = as_dtype / tf.cast(255.0, dtype=self.dtype)
     one_hot_label = tf.one_hot(label, depth=self.output_size)
     return scaled_image, tf.cast(one_hot_label, tf.int32)
Example #14
0
    def update(self,
               expert_dataset_iter,
               policy_dataset_iter,
               discount,
               replay_regularization=0.05,
               nu_reg=10.0):
        """Runs one DualDICE-style update of the nu network and the actor.

        When replay regularization is non-zero, the intended objective learns
        (d_pi * (1 - replay_regularization) + d_rb * replay_regularization) /
        (d_expert * (1 - replay_regularization) + d_rb * replay_regularization)
        instead.

        NOTE(review): the replay-buffer terms are commented out below. As a
        result, `policy_dataset_iter` is unused, and `replay_regularization`
        only scales the linear (initial-state) loss term.

        Args:
          expert_dataset_iter: iterator whose `get_next()` yields
            `(states, actions, next_states)` for expert data.
          policy_dataset_iter: iterator over training-policy (replay) data,
            meant for regularization; currently unused.
          discount: MDP discount factor.
          replay_regularization: fraction of samples meant to come from a
            replay buffer.
          nu_reg: coefficient of the gradient penalty on the nu network.
        """

        (expert_states, expert_actions,
         expert_next_states) = expert_dataset_iter.get_next()

        # Expert states double as the initial-state distribution.
        expert_initial_states = expert_states

        # rb_states, rb_actions, rb_next_states, _, _ = policy_dataset_iter.get_next(
        # )[0]

        # Persistent tape: gradients are taken twice (nu loss and actor loss).
        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self.actor.variables)
            tape.watch(self.nu_net.variables)

            _, policy_next_actions, _ = self.actor(expert_next_states)
            # _, rb_next_actions, rb_log_prob = self.actor(rb_next_states)

            _, policy_initial_actions, _ = self.actor(expert_initial_states)

            # Inputs for the linear part of DualDICE loss.
            expert_init_inputs = tf.concat(
                [expert_initial_states, policy_initial_actions], 1)

            # Discrete actions are one-hot encoded before concatenation.
            if not self.discrete:
                expert_inputs = tf.concat([expert_states, expert_actions], 1)
            else:
                mat = tf.one_hot(tf.cast(expert_actions, tf.int32),
                                 depth=self.action_dim,
                                 axis=-1)
                expert_inputs = tf.concat([expert_states, mat], 1)
            expert_next_inputs = tf.concat(
                [expert_next_states, policy_next_actions], 1)

            # rb_inputs = tf.concat([rb_states, rb_actions], 1)
            # rb_next_inputs = tf.concat([rb_next_states, rb_next_actions], 1)

            expert_nu_0 = self.nu_net(expert_init_inputs)
            expert_nu = self.nu_net(expert_inputs)
            expert_nu_next = self.nu_net(expert_next_inputs)

            # rb_nu = self.nu_net(rb_inputs)
            # rb_nu_next = self.nu_net(rb_next_inputs)

            expert_diff = expert_nu - discount * expert_nu_next
            # rb_diff = rb_nu - discount * rb_nu_next

            linear_loss_expert = tf.reduce_mean(expert_nu_0 * (1 - discount))

            # linear_loss_rb = tf.reduce_mean(rb_diff)

            # With the replay terms disabled, only expert diffs are used, with
            # uniform weights (normalized to sum to 1 below).
            rb_expert_diff = expert_diff  #tf.concat([expert_diff, rb_diff], 0)
            rb_expert_weights = tf.ones(expert_diff.shape)  #tf.concat([
            #     tf.ones(expert_diff.shape) * (1 - replay_regularization),
            #     tf.ones(rb_diff.shape) * replay_regularization
            # ], 0)

            rb_expert_weights /= tf.reduce_sum(rb_expert_weights)
            # Non-linear term: diffs weighted by a stop-gradient weighted
            # softmax of themselves.
            non_linear_loss = tf.reduce_sum(
                tf.stop_gradient(
                    weighted_softmax(rb_expert_diff, rb_expert_weights,
                                     axis=0)) * rb_expert_diff)

            linear_loss = (linear_loss_expert * (1 - replay_regularization) +
                           0)
            # linear_loss_rb * replay_regularization)

            loss = (non_linear_loss - linear_loss)

            # Gradient penalty on random interpolations between expert inputs
            # and shuffled expert next inputs.
            alpha = tf.random.uniform(shape=(expert_inputs.shape[0], 1))

            # nu_inter = alpha * expert_inputs + (1 - alpha) * expert_init_inputs #rb_inputs
            # nu_next_inter = alpha * expert_next_inputs + (1 - alpha) * #rb_next_inputs

            # nu_inter = tf.concat([nu_inter, nu_next_inter], 0)
            nu_inter = alpha * expert_inputs + (1 - alpha) * tf.stop_gradient(
                tf.random.shuffle(expert_next_inputs))

            with tf.GradientTape(watch_accessed_variables=False) as tape2:
                tape2.watch(nu_inter)
                nu_output = self.nu_net(nu_inter)
            # EPS is presumably added to keep the gradient norm away from zero.
            nu_grad = tape2.gradient(nu_output, [nu_inter])[0] + EPS
            nu_grad_penalty = tf.reduce_mean(
                tf.square(tf.norm(nu_grad, axis=-1, keepdims=True) - 1))

            # nu minimizes `loss`; the actor minimizes `-loss` (adversarial),
            # plus an orthogonal regularizer on its trunk.
            nu_loss = loss + nu_grad_penalty * nu_reg
            pi_loss = -loss + keras_utils.orthogonal_regularization(
                self.actor.trunk)

        nu_grads = tape.gradient(nu_loss, self.nu_net.variables)
        pi_grads = tape.gradient(pi_loss, self.actor.variables)

        self.nu_optimizer.apply_gradients(zip(nu_grads, self.nu_net.variables))
        self.actor_optimizer.apply_gradients(
            zip(pi_grads, self.actor.variables))

        # Release the persistent tape's resources.
        del tape

        self.avg_nu_expert(expert_nu)
        #self.avg_nu_rb(rb_nu)

        self.nu_reg_metric(nu_grad_penalty)
        self.avg_loss(loss)

        self.avg_actor_loss(pi_loss)
        #self.avg_actor_entropy(-rb_log_prob)

        # Every `log_interval` optimizer steps, emit summaries and reset metrics.
        if tf.equal(self.nu_optimizer.iterations % self.log_interval, 0):
            tf.summary.scalar('train dual dice/loss',
                              self.avg_loss.result(),
                              step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_loss)

            tf.summary.scalar('train dual dice/nu expert',
                              self.avg_nu_expert.result(),
                              step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_nu_expert)

            # NOTE(review): `avg_nu_rb` is never updated in this method (its
            # update is commented out above), so this logs stale/empty values.
            tf.summary.scalar('train dual dice/nu rb',
                              self.avg_nu_rb.result(),
                              step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_nu_rb)

            tf.summary.scalar('train dual dice/nu reg',
                              self.nu_reg_metric.result(),
                              step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.nu_reg_metric)

        if tf.equal(self.actor_optimizer.iterations % self.log_interval, 0):
            tf.summary.scalar('train sac/actor_loss',
                              self.avg_actor_loss.result(),
                              step=self.actor_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_actor_loss)

            # NOTE(review): `avg_actor_entropy` is likewise never updated here.
            tf.summary.scalar('train sac/actor entropy',
                              self.avg_actor_entropy.result(),
                              step=self.actor_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_actor_entropy)
def boolean_mask(boxlist,
                 indicator,
                 fields=None,
                 scope=None,
                 use_static_shapes=False,
                 indicator_sum=None):
    """Select boxes from BoxList according to indicator and return new BoxList.

    `boolean_mask` returns the subset of boxes that are marked as "True" by the
    indicator tensor. By default, `boolean_mask` returns boxes corresponding to
    the input index list, as well as all additional fields stored in the
    boxlist (indexing into the first dimension). However one can optionally
    only draw from a subset of fields.

    Args:
      boxlist: BoxList holding N boxes
      indicator: a rank-1 boolean tensor
      fields: (optional) list of fields to also gather from.  If None (default),
        all fields are gathered from.  Pass an empty fields list to only gather
        the box coordinates. Honored by both the dynamic and the static-shape
        implementations.
      scope: name scope.
      use_static_shapes: Whether to use an implementation with static shape
        guarantees.
      indicator_sum: An integer containing the sum of `indicator` vector. Only
        required if `use_static_shapes` is True.

    Returns:
      subboxlist: a BoxList corresponding to the subset of the input BoxList
        specified by indicator
    Raises:
      ValueError: if `indicator` is not a rank-1 boolean tensor, if
        `use_static_shapes` is True and `indicator_sum` is not a non-zero int,
        or if a requested field is missing from `boxlist`.
    """
    with tf.name_scope(scope, 'BooleanMask'):
        if indicator.shape.ndims != 1:
            raise ValueError('indicator should have rank 1')
        if indicator.dtype != tf.bool:
            raise ValueError('indicator should be a boolean tensor')
        if use_static_shapes:
            if not (indicator_sum and isinstance(indicator_sum, int)):
                raise ValueError('`indicator_sum` must be a of type int')
            selected_positions = tf.cast(indicator, dtype=tf.float32)
            # 1-based output slot of each selected box; 0 for unselected boxes.
            indexed_positions = tf.cast(
                tf.multiply(tf.cumsum(selected_positions), selected_positions),
                dtype=tf.int32)
            # Row j of the selector one-hot encodes box j's output slot;
            # unselected boxes map to index -1, i.e. an all-zero row.
            one_hot_selector = tf.one_hot(indexed_positions - 1,
                                          indicator_sum,
                                          dtype=tf.float32)
            # Contracting the box indices [0..N) with the selector yields, for
            # each output slot, the index of the box that fills it.
            box_indices = tf.cast(tf.range(tf.shape(indicator)[0]),
                                  dtype=tf.float32)
            sampled_indices = tf.cast(tf.tensordot(box_indices,
                                                   one_hot_selector,
                                                   axes=[0, 0]),
                                      dtype=tf.int32)
            # Forward `fields` so the static path respects the same field
            # selection as the dynamic path (previously it was dropped).
            return gather(boxlist, sampled_indices, fields=fields,
                          use_static_shapes=True)
        else:
            subboxlist = box_list.BoxList(
                tf.boolean_mask(boxlist.get(), indicator))
            if fields is None:
                fields = boxlist.get_extra_fields()
            for field in fields:
                if not boxlist.has_field(field):
                    raise ValueError(
                        'boxlist must contain all specified fields')
                subfieldlist = tf.boolean_mask(boxlist.get_field(field),
                                               indicator)
                subboxlist.add_field(field, subfieldlist)
            return subboxlist
def _binary_crossover(population,
                      population_size,
                      mutants,
                      crossover_prob,
                      seed):
  """Performs recombination by binary crossover for the current population.

  Let v_i denote the i'th component of the member v and m_i the corresponding
  component of the mutant vector corresponding to v. Then the crossed over
  vector w_i is determined by setting w_i =
  (m_i with probability=crossover_prob else v_i). In addition, DE requires that
  at least one of the components is crossed over (otherwise we end
  up with no change). This is done by choosing an index, say k, uniformly at
  random among all components of the member and forcing a crossover there
  (i.e. w_k = m_k). This is the scheme implemented in this function.

  Args:
    population: A Python list of `Tensor`s where each `Tensor` in the list
      must be of rank at least 1 and all the elements must have a common
      first dimension. The base population to cross over.
    population_size: A scalar integer `Tensor`. The number of elements in the
      population (i.e. size of the first dimension of any member of
      `population`).
    mutants: A Python list of `Tensor`s with the same structure as `population`.
      The mutated population.
    crossover_prob: A positive real scalar `Tensor` bounded above by 1.0. The
      probability of a crossover being performed for each axis.
    seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
      applied.

  Returns:
    A list of `Tensor`s of the same structure, dtype and shape as `population`.
    The recombined population.
  """
  # Element count of each population part; a member's components are spread
  # across the parts in proportion to these counts.
  sizes = [tf.cast(tf.size(x), dtype=tf.float64) for x in population]
  seed_stream = tfp_util.SeedStream(seed, salt='binary_crossover')
  # For each member, pick the part that receives the forced crossover with
  # probability proportional to the part's size, so that the forced index is
  # uniform over all components. `Categorical`'s first positional argument is
  # `logits`; passing raw sizes there would weight parts by exp(size).
  force_crossover_group = distributions.Categorical(
      logits=tf.math.log(sizes)).sample(
          [population_size, 1], seed=seed_stream())
  recombinants = []
  for i, population_part in enumerate(population):
    pop_part_flat = tf.reshape(population_part, [population_size, -1])
    mutant_part_flat = tf.reshape(mutants[i], [population_size, -1])
    part_size = tf.size(population_part) // population_size
    # Within the part, one uniformly chosen component per member is a
    # forced-crossover candidate.
    force_crossovers = tf.one_hot(
        tf.random.uniform([population_size],
                          minval=0,
                          maxval=part_size,
                          dtype=tf.int32,
                          seed=seed_stream()),
        part_size,
        on_value=True,
        off_value=False,
        dtype=tf.bool)  # Tensor of shape [population_size, size]
    # Keep the candidate only for members whose forced group is this part.
    group_mask = tf.math.equal(force_crossover_group, i)
    force_crossovers &= group_mask
    do_binary_crossover = tf.random.uniform(
        [population_size, part_size],
        dtype=crossover_prob.dtype.base_dtype,
        seed=seed_stream()) < crossover_prob
    do_binary_crossover |= force_crossovers
    recombinant_flat = tf1.where(
        do_binary_crossover, mutant_part_flat, pop_part_flat)
    recombinant = tf.reshape(recombinant_flat, tf.shape(population_part))
    recombinants.append(recombinant)
  return recombinants
Example #17
0
    def _build_target_quantile_values_op(self):
        """Build the target op for return values at the sampled quantiles.

        The replay rewards are first augmented with a self-imitation (SIL)
        bonus, `alpha * (max(Q_target(s, a), R) - max_a' Q_target(s, a'))`.
        The bonus difference is clipped to `[-clip, clip]` when
        `self._clip > 0`. The usual discounted quantile target is then formed
        from the next-state target quantiles at the argmax action.

        Returns:
          An op calculating the target quantile return, shaped
          `(num_tau_prime_samples x batch_size) x 1`.
        """
        replay = self._replay
        batch_size = tf.shape(replay.rewards)[0]

        # --- SIL bonus ---
        taken_action_mask = tf.one_hot(replay.actions,
                                       self.num_actions,
                                       1.,
                                       0.,
                                       name='action_one_hot')
        best_target_q = tf.reduce_max(self._replay_target_q_values,
                                      axis=1,
                                      name='replay_chosen_target_q')
        taken_target_q = tf.reduce_sum(taken_action_mask *
                                       self._replay_target_q_values,
                                       axis=1,
                                       name='replay_chosen_target_q_al')
        bonus_base = tf.math.maximum(taken_target_q, replay.returns)
        gap = bonus_base - best_target_q
        if self._clip > 0.:
            gap = tf.clip_by_value(gap, -self._clip, self._clip)
        sil_bonus = self._alpha * gap

        # Every per-transition column is repeated once per tau' sample.
        tau_tiling = [self.num_tau_prime_samples, 1]

        # (num_tau_prime_samples x batch_size) x 1
        augmented_rewards = (replay.rewards + sil_bonus)[:, None]
        tiled_rewards = tf.tile(augmented_rewards, tau_tiling)

        # Terminal transitions get a zero discount.
        # (num_tau_prime_samples x batch_size) x 1
        alive = 1. - tf.cast(replay.terminals, tf.float32)
        discount = self.cumulative_gamma * alive
        tiled_discount = tf.tile(discount[:, None], tau_tiling)

        # (num_tau_prime_samples x batch_size) x 1
        tiled_argmax = tf.tile(self._replay_next_qt_argmax[:, None],
                               tau_tiling)

        # (num_tau_prime_samples x batch_size) x 1
        row_ids = tf.cast(
            tf.range(self.num_tau_prime_samples * batch_size)[:, None],
            tf.int64)

        # (num_tau_prime_samples x batch_size) x 2 -- (row, action) pairs.
        gather_ids = tf.concat([row_ids, tiled_argmax], axis=1)

        # (num_tau_prime_samples x batch_size) x 1
        next_quantiles = tf.gather_nd(
            self._replay_net_target_quantile_values, gather_ids)[:, None]

        return tiled_rewards + tiled_discount * next_quantiles
Example #18
0
def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    trace_fn=lambda current_state, kernel_results: kernel_results,
    return_final_kernel_results=False,
    parallel_iterations=10,
    name=None,
):
    """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from an Markov chain at `current_state` and whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or not
  there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set of
  states with decreased autocorrelation.  See [Owen (2017)][1]. Such "thinning"
  is made possible by setting `num_steps_between_results > 0`. The chain then
  takes `num_steps_between_results` extra steps between the steps that make it
  into the results. The extra steps are never materialized, and thus do not
  increase memory requirements.

  Warning: when setting a `seed` in the `kernel`, ensure that `sample_chain`'s
  `parallel_iterations=1`, otherwise results will not be reproducible.

  In addition to returning the chain state, this function supports tracing of
  auxiliary variables used by the kernel. The traced values are selected by
  specifying `trace_fn`. By default, all kernel results are traced but in the
  future the default will be changed to no results being traced, so plan
  accordingly. See below for some examples of this feature.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
      representing internal calculations made within the previous call to this
      function (or as returned by `bootstrap_results`).
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one step
      of the Markov chain.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_samples + 1` steps is
      included in the returned results.  The number of returned chain states is
      still equal to `num_results`.  Default value: 0 (i.e., no thinning).
    trace_fn: A callable that takes in the current chain state and the previous
      kernel results and return a `Tensor` or a nested collection of `Tensor`s
      that is then traced along with the chain state.
    return_final_kernel_results: If `True`, then the final kernel results are
      returned alongside the chain state and the trace specified by the
      `trace_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer. See `tf.while_loop` for more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "mcmc_sample_chain").

  Returns:
    checkpointable_states_and_trace: if `return_final_kernel_results` is
      `True`. The return value is an instance of
      `CheckpointableStatesAndTrace`.
    all_states: if `return_final_kernel_results` is `False` and `trace_fn` is
      `None`. The return value is a `Tensor` or Python list of `Tensor`s
      representing the state(s) of the Markov chain(s) at each result step. Has
      same shape as input `current_state` but with a prepended
      `num_results`-size dimension.
    states_and_trace: if `return_final_kernel_results` is `False` and
      `trace_fn` is not `None`. The return value is an instance of
      `StatesAndTrace`.

  #### Examples

  ##### Sample from a diagonal-variance Gaussian.

  I.e.,

  ```none
  for i=1..n:
    x[i] ~ MultivariateNormal(loc=0, scale=diag(true_stddev))  # likelihood
  ```

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  dims = 10
  true_stddev = np.sqrt(np.linspace(1., 3., dims))
  likelihood = tfd.MultivariateNormalDiag(loc=0., scale_diag=true_stddev)

  states = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros(dims),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=None)

  sample_mean = tf.reduce_mean(states, axis=0)
  # ==> approx all zeros

  sample_stddev = tf.sqrt(tf.reduce_mean(
      tf.squared_difference(states, sample_mean),
      axis=0))
  # ==> approx equal true_stddev
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  # prior
  w ~ MultivariateNormal(loc=0, scale=eye(d))
  for i=1..n:
    # likelihood
    x[i] ~ Normal(loc=w^T F[i], scale=1)
  ```

  where `F` denotes factors.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Specify model.
  def make_prior(dims):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.matmul(weights, factors, adjoint_b=True))

  def joint_log_prob(num_weights, factors, x, w):
    return (make_prior(num_weights).log_prob(w) +
            make_likelihood(w, factors).log_prob(x))

  def unnormalized_log_posterior(w):
    # Posterior is proportional to: `p(W, X=x | factors)`.
    return joint_log_prob(num_weights, factors, x, w)

  # Setup data.
  num_weights = 10 # == d
  num_factors = 40 # == n
  num_chains = 100

  weights = make_prior(num_weights).sample(1)
  factors = tf.random_normal([num_factors, num_weights])
  x = make_likelihood(weights, factors).sample()

  # Sample from Hamiltonian Monte Carlo Markov Chain.

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros([num_chains, num_weights], name='init_weights'),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_log_posterior,
        step_size=0.1,
        num_leapfrog_steps=2))

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  # ==> approx equal to weights

  sample_var = tf.reduce_mean(
      tf.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  # ==> less than 1
  ```

  ##### Custom tracing functions.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  likelihood = tfd.Normal(loc=0., scale=1.)

  def sample_chain(trace_fn):
    return tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=0.,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=trace_fn)

  def trace_log_accept_ratio(states, previous_kernel_results):
    return previous_kernel_results.log_accept_ratio

  def trace_everything(states, previous_kernel_results):
    return previous_kernel_results

  _, log_accept_ratio = sample_chain(trace_fn=trace_log_accept_ratio)
  _, kernel_results = sample_chain(trace_fn=trace_everything)

  acceptance_prob = tf.math.exp(tf.minimum(log_accept_ratio, 0.))
  # Equivalent to, but more efficient than:
  acceptance_prob = tf.math.exp(tf.minimum(
      kernel_results.log_accept_ratio, 0.))
  ```

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
       _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
    if not kernel.is_calibrated:
        warnings.warn(
            "supplied `TransitionKernel` is not calibrated. Markov "
            "chain may not converge to intended target distribution.")
    with tf.name_scope(name or "mcmc_sample_chain"):
        num_results = tf.convert_to_tensor(num_results,
                                           dtype=tf.int32,
                                           name="num_results")
        num_burnin_steps = tf.convert_to_tensor(num_burnin_steps,
                                                dtype=tf.int32,
                                                name="num_burnin_steps")
        num_steps_between_results = tf.convert_to_tensor(
            num_steps_between_results,
            dtype=tf.int32,
            name="num_steps_between_results")
        current_state = tf.nest.map_structure(
            lambda x: tf.convert_to_tensor(x, name="current_state"),
            current_state)
        if previous_kernel_results is None:
            previous_kernel_results = kernel.bootstrap_results(current_state)

        if trace_fn is None:
            # It simplifies the logic to use a dummy function here.
            trace_fn = lambda *args: ()
            no_trace = True
        else:
            no_trace = False
        if trace_fn is sample_chain.__defaults__[4]:
            warnings.warn(
                "Tracing all kernel results by default is deprecated. Set "
                "the `trace_fn` argument to None (the future default "
                "value) or an explicit callback that traces the values "
                "you are interested in.")

        def _trace_scan_fn(state_and_results, num_steps):
            next_state, current_kernel_results = mcmc_util.smart_for_loop(
                loop_num_iter=num_steps,
                body_fn=kernel.one_step,
                initial_loop_vars=list(state_and_results),
                parallel_iterations=parallel_iterations)
            return next_state, current_kernel_results

        (_, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
            loop_fn=_trace_scan_fn,
            initial_state=(current_state, previous_kernel_results),
            elems=tf.one_hot(indices=0,
                             depth=num_results,
                             on_value=1 + num_burnin_steps,
                             off_value=1 + num_steps_between_results,
                             dtype=tf.int32),
            # pylint: disable=g-long-lambda
            trace_fn=lambda state_and_results:
            (state_and_results[0], trace_fn(*state_and_results)),
            # pylint: enable=g-long-lambda
            parallel_iterations=parallel_iterations)

        if return_final_kernel_results:
            return CheckpointableStatesAndTrace(
                all_states=all_states,
                trace=trace,
                final_kernel_results=final_kernel_results)
        else:
            if no_trace:
                return all_states
            else:
                return StatesAndTrace(all_states=all_states, trace=trace)
    def prepare_dataset(self, dataset: dataset_lib.OffpolicyDataset,
                        target_policy: tf_policy.TFPolicy):
        """Performs pre-computations on dataset to make solving easier.

        Reads all episodes from `dataset` and flattens them into per-transition
        samples of (initial step, current step, next step). It keeps only
        transitions that are valid, have a positive discount and are not the
        final step of an episode.

        From these samples it builds the tabular quantities used by the
        solver:
          * `a_vec`: Bellman residual rows of shape [n_data, n_dim].
          * `td_mat`: count-normalized TD matrix of shape [n_dim, n_dim].
          * `bias`: count-normalized reward vector of shape [n_dim].

        All results are stored on `self`. Every row `i` of `self._nu` and
        `self._nu2` is re-initialized to `bias[i]`.

        Args:
          dataset: Off-policy dataset. Episodes are read with
            `get_all_episodes(limit=self._limit_episodes)`.
          target_policy: Policy whose action distribution supplies the target
            action probabilities and log-probabilities.
        """
        episodes, valid_steps = dataset.get_all_episodes(
            limit=self._limit_episodes)
        total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
        num_episodes = tf.shape(valid_steps)[0]
        num_samples = num_episodes * total_num_steps_per_episode
        valid_and_not_last = tf.logical_and(valid_steps, episodes.discount > 0)
        # `tf.where` on a rank-1 mask returns shape [k, 1]. Squeeze only that
        # trailing axis. A bare `tf.squeeze` turns a single valid index into a
        # scalar, and the gathers below would then drop the sample dimension.
        valid_indices = tf.squeeze(
            tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])), axis=1)

        def _keep_valid(nest):
            # Restrict every flattened tensor in `nest` to the valid samples.
            return tf.nest.map_structure(
                lambda t: tf.gather(t, valid_indices), nest)

        # Flatten all tensors so that each data sample is a tuple of
        # (initial_env_step, env_step, next_env_step).
        # The first step of each episode is repeated once per transition.
        initial_env_step = _keep_valid(
            tf.nest.map_structure(
                lambda t: tf.squeeze(
                    tf.reshape(
                        tf.repeat(t[:, 0:1, ...],
                                  axis=1,
                                  repeats=total_num_steps_per_episode),
                        [num_samples, -1])), episodes))
        tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
            initial_env_step)

        env_step = _keep_valid(
            tf.nest.map_structure(
                lambda t: tf.squeeze(
                    tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                               [num_samples, -1])), episodes))
        tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

        next_env_step = _keep_valid(
            tf.nest.map_structure(
                lambda t: tf.squeeze(
                    tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                               [num_samples, -1])), episodes))
        tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
            next_env_step)

        # Get target probabilities for initial and next steps.
        initial_target_probs = target_policy.distribution(
            tfagents_initial_env_step).action.probs_parameter()
        next_target_probs = target_policy.distribution(
            tfagents_next_env_step).action.probs_parameter()

        # Map states and actions to indices into tabular representation.
        # Each sample is paired with every action: shapes [n_data, n_actions].
        # Dynamic shapes are used so this also works when static shapes are
        # unknown.
        initial_states = tf.tile(
            tf.reshape(initial_env_step.observation, [-1, 1]),
            [1, self._num_actions])
        initial_actions = tf.tile(
            tf.reshape(tf.range(self._num_actions), [1, -1]),
            [tf.shape(initial_env_step.observation)[0], 1])
        initial_nu_indices = self._get_index(initial_states, initial_actions)

        next_states = tf.tile(tf.reshape(next_env_step.observation, [-1, 1]),
                              [1, self._num_actions])
        next_actions = tf.tile(
            tf.reshape(tf.range(self._num_actions), [1, -1]),
            [tf.shape(next_env_step.observation)[0], 1])
        next_nu_indices = self._get_index(next_states, next_actions)
        # Absorbing next steps get index -1. `tf.one_hot(-1, ...)` is all
        # zeros, so they contribute nothing to the bootstrap term below.
        next_nu_indices = tf.where(
            tf.expand_dims(next_env_step.is_absorbing(), -1),
            -1 * tf.ones_like(next_nu_indices), next_nu_indices)

        nu_indices = self._get_index(env_step.observation, env_step.action)

        target_log_probabilities = target_policy.distribution(
            tfagents_env_step).action.log_prob(env_step.action)
        if not self._solve_for_state_action_ratio:
            # Importance ratio between target and behavior policy.
            policy_ratio = tf.exp(target_log_probabilities -
                                  env_step.get_log_probability())
        else:
            policy_ratio = tf.ones([tf.shape(target_log_probabilities)[0]])
        policy_ratios = tf.tile(tf.reshape(policy_ratio, [-1, 1]),
                                [1, self._num_actions])

        # Bellman residual matrix of size [n_data, n_dim].
        a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
            self._gamma *
            tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
            tf.one_hot(next_nu_indices, self._dimension),
            axis=1)

        state_action_count = self._get_state_action_counts(env_step)
        # Bellman residual matrix of size [n_dim, n_dim].
        td_mat = tf.einsum('ai, a, aj -> ij',
                           tf.one_hot(nu_indices, self._dimension),
                           1.0 / tf.cast(state_action_count, tf.float32),
                           a_vec)

        # Reward vector of size [n_data].
        weighted_rewards = policy_ratio * self._reward_fn(env_step)

        # Reward vector of size [n_dim].
        bias = tf.reduce_sum(tf.one_hot(nu_indices, self._dimension) *
                             tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
                             tf.cast(state_action_count, tf.float32)[:, None],
                             axis=0)

        # Initialize.
        self._nu = np.ones_like(self._nu) * bias[:, None]
        self._nu2 = np.ones_like(self._nu2) * bias[:, None]

        self._a_vec = a_vec
        self._td_mat = td_mat
        self._bias = bias
        self._weighted_rewards = weighted_rewards
        self._state_action_count = state_action_count
        self._nu_indices = nu_indices
        self._initial_nu_indices = initial_nu_indices
        self._initial_target_probs = initial_target_probs
Example #20
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples from the mixture.

        Two strategies are implemented.

        * Static-graph path (`self._use_static_graph`):
          1. Draw `n` samples from every component.
          2. Stack them on a new component axis just before the event dims.
          3. Multiply by a one-hot mask built from the categorical samples.
          4. Sum over the component axis, which keeps one component per draw.
        * Default path:
          1. Draw the categorical samples first.
          2. Partition the flat sample indices by chosen component with
             `tf.dynamic_partition`.
          3. Draw from each component only as many samples as were assigned
             to it, and gather the matching batch entries.
          4. Reassemble everything with `tf.dynamic_stitch`.

        Both paths run under `tf.control_dependencies(self._assertions)`. Both
        seed the component draws from `seed_stream.SeedStream(seed,
        salt="Mixture")`.

        Args:
          n: Number of samples to draw. It is converted with
            `tf.convert_to_tensor` and, when statically known, turned back
            into a Python int.
          seed: Seed passed to the categorical sampler. It also seeds the
            per-component seed stream.

        Returns:
          Samples shaped `[n] + batch_shape + event_shape`, written `[n, B, E]`
          in the inline shape comments.
        """
        if self._use_static_graph:
            with tf.control_dependencies(self._assertions):
                # This sampling approach is almost the same as the approach used by
                # `MixtureSameFamily`. The differences are due to having a list of
                # `Distribution` objects rather than a single object, and maintaining
                # random seed management that is consistent with the non-static code
                # path.
                samples = []
                cat_samples = self.cat.sample(n, seed=seed)
                stream = seed_stream.SeedStream(seed, salt="Mixture")

                # Every component is sampled `n` times. The unused draws are
                # masked out below.
                for c in range(self.num_components):
                    samples.append(self.components[c].sample(n, seed=stream()))
                x = tf.stack(samples, -self._static_event_shape.ndims -
                             1)  # [n, B, k, E]
                npdt = x.dtype.as_numpy_dtype
                mask = tf.one_hot(
                    indices=cat_samples,  # [n, B]
                    depth=self._num_components,  # == k
                    on_value=np.ones([], dtype=npdt),
                    off_value=np.zeros([], dtype=npdt))  # [n, B, k]
                # Append singleton event dims so the mask broadcasts against x.
                mask = distribution_util.pad_mixture_dimensions(
                    mask, self, self._cat,
                    self._static_event_shape.ndims)  # [n, B, k, [1]*e]
                # Summing over the component axis keeps only the selected draw.
                return tf.reduce_sum(
                    input_tensor=x * mask,
                    axis=-1 - self._static_event_shape.ndims)  # [n, B, E]

        with tf.control_dependencies(self._assertions):
            # Prefer a Python int for `n` when its value is statically known.
            n = tf.convert_to_tensor(value=n, name="n")
            static_n = tf.get_static_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            # Use static shapes/sizes when fully defined; otherwise fall back
            # to their dynamic tensor equivalents.
            static_samples_shape = cat_samples.shape
            if static_samples_shape.is_fully_defined():
                samples_shape = static_samples_shape.as_list()
                samples_size = static_samples_shape.num_elements()
            else:
                samples_shape = tf.shape(input=cat_samples)
                samples_size = tf.size(input=cat_samples)
            static_batch_shape = self.batch_shape
            if static_batch_shape.is_fully_defined():
                batch_shape = static_batch_shape.as_list()
                batch_size = static_batch_shape.num_elements()
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = tf.reduce_prod(input_tensor=batch_shape)
            static_event_shape = self.event_shape
            if static_event_shape.is_fully_defined():
                event_shape = np.array(static_event_shape.as_list(),
                                       dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                             samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = tf.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = tf.reshape(
                tf.tile(tf.range(0, batch_size), [n]), samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [0 1] and [0 1 0 1]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 1, 0, 1
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = tf.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            stream = seed_stream.SeedStream(seed, salt="Mixture")

            for c in range(self.num_components):
                # Number of (sample, batch) cells assigned to component c.
                n_class = tf.size(input=partitioned_samples_indices[c])
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=stream())

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * tf.range(n_class) +
                    partitioned_batch_indices[c])
                # Flatten to [n_class * batch_size, E...] so row
                # `s * batch_size + b` holds sample s of batch entry b.
                samples_class_c = tf.reshape(
                    samples_class_c,
                    tf.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = tf.gather(samples_class_c,
                                            lookup_partitioned_batch_indices,
                                            name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = tf.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = tf.reshape(
                lhs_flat_ret,
                tf.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            ret.set_shape(
                tf.TensorShape(static_samples_shape).concatenate(
                    self.event_shape))
            return ret
Example #21
0
 def _write_update_to_result():
   """Blends `new_tensor` into `tensor` at position `ind` along `axis`.

   Builds a boolean selector that is True only at index `ind` of `axis`
   (length `size_along_axis`) and has size 1 in every other dimension.
   `tf.where` then takes `new_tensor` where the selector is set and `tensor`
   elsewhere. All inputs come from the enclosing scope.

   NOTE(review): the reshape requires exactly `size_along_axis` mask
   entries, so `ind` is expected to be a single index.
   """
   # Rank-matching shape with one non-trivial dimension at `axis`.
   broadcast_shape = [1] * len(tensor.shape)
   broadcast_shape[axis] = size_along_axis
   selector = tf.reshape(tf.one_hot(ind, depth=size_along_axis) > 0,
                         broadcast_shape)
   return tf.where(selector, new_tensor, tensor)
                labels_one_hot=tf.constant(y_train_onehot),
                samples=tf.constant(samples),
                weights=tf.constant(weights),
                _lambda=tf.constant(_lambda))
            loss1, share_loss1, tau = loss1.numpy(), share_loss1.numpy(
            ), tau.numpy()

        save_tau.append(tau)
        save_loss1.append(loss1)
        save_share_loss1.append(share_loss1)

        black_box_probs = black_box(X_train, trainable=tf.constant(False))
        black_box_labels = np.argmax(black_box_probs.numpy(), axis=1)

        if not (config_params["weights"]):
            black_box_probs = tf.one_hot(black_box_labels,
                                         len(np.unique(y_train)))

        #Learning sTGMA
        #tf.print("----Begin----")
        for j in range(hyper_params["stgma_steps"]):
            #print("Iteration sTGMA: ", j)
            toc = time()
            #print(X_train.shape,black_box_labels.astype(np.int32).shape, responsibilities.shape, samples.shape, black_box_probs.shape,  tf.random.shuffle(tf.range(model.data_dim), seed=(2^step)*(2*j+1) ))
            loss, share_loss2 = train_step_sTGMA(
                data=tf.constant(X_train),
                labels=tf.constant(black_box_labels.astype(np.int32)),
                responsibilities=tf.constant(responsibilities),
                eta=tf.constant(eta, dtype=tf.float32),
                samples=tf.constant(samples),
                weights=tf.constant(black_box_probs),
                t_range=tf.random.shuffle(tf.range(model.data_dim),
Example #23
0
        def step_fn(inputs):
            """Per-replica training step.

            Runs one forward/backward pass on a batch of `(images, labels)`
            and applies the gradients with `optimizer`. The loss is the sum of
            three terms, divided by `strategy.num_replicas_in_sync` before
            differentiation:
              * negative log-likelihood;
              * an L2 penalty on variables whose names contain 'kernel' or
                'bias';
              * an annealed KL term.

            Training metrics are updated afterwards. When `FLAGS.version2` and
            `FLAGS.ensemble_size > 1`, pairwise-diversity metrics are updated
            as well.

            Args:
              inputs: A `(images, labels)` pair for this replica. `images` is
                tiled with four multiples below, so it must be rank-4.
            """
            images, labels = inputs
            # Replicate the batch once per ensemble member. Labels are
            # replicated too unless predictions are later collapsed back to
            # the original batch size (member sampling / expected probs).
            if FLAGS.version2 and FLAGS.ensemble_size > 1:
                images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
                if not (FLAGS.member_sampling or FLAGS.expected_probs):
                    labels = tf.tile(labels, [FLAGS.ensemble_size])

            # Replicate again for multiple stochastic samples. Their
            # probabilities are averaged below, so labels are not tiled here.
            if FLAGS.num_train_samples > 1:
                images = tf.tile(images, [FLAGS.num_train_samples, 1, 1, 1])

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                probs = tf.nn.softmax(logits)
                # Diversity evaluation.
                if FLAGS.version2 and FLAGS.ensemble_size > 1:
                    per_probs = tf.reshape(
                        probs,
                        tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]],
                                  0))

                    diversity_results = ed.metrics.average_pairwise_diversity(
                        per_probs, FLAGS.ensemble_size)

                # Average the probabilities over the stochastic samples.
                if FLAGS.num_train_samples > 1:
                    probs = tf.reshape(
                        probs,
                        tf.concat(
                            [[FLAGS.num_train_samples, -1], probs.shape[1:]],
                            0))
                    probs = tf.reduce_mean(probs, 0)

                # Member sampling: keep the predictions of one ensemble member,
                # chosen uniformly at random, for the whole batch. The one-hot
                # row vector selects it via matmul over the flattened members.
                if FLAGS.member_sampling and FLAGS.version2 and FLAGS.ensemble_size > 1:
                    idx = tf.random.uniform([],
                                            maxval=FLAGS.ensemble_size,
                                            dtype=tf.int64)
                    idx_one_hot = tf.expand_dims(
                        tf.one_hot(idx, FLAGS.ensemble_size,
                                   dtype=probs.dtype), 0)
                    probs_shape = probs.shape
                    probs = tf.reshape(probs, [FLAGS.ensemble_size, -1])
                    probs = tf.matmul(idx_one_hot, probs)
                    probs = tf.reshape(probs,
                                       tf.concat([[-1], probs_shape[1:]], 0))

                # Expected probs: average the predictions over ensemble members.
                elif FLAGS.expected_probs and FLAGS.version2 and FLAGS.ensemble_size > 1:
                    probs = tf.reshape(
                        probs,
                        tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]],
                                  0))
                    probs = tf.reduce_mean(probs, 0)

                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, probs))

                filtered_variables = []
                for var in model.trainable_variables:
                    # Apply l2 on the slow weights and bias terms. This excludes BN
                    # parameters and fast weight approximate posterior/prior parameters,
                    # but pay caution to their naming scheme.
                    if 'kernel' in var.name or 'bias' in var.name:
                        filtered_variables.append(tf.reshape(var, (-1, )))

                # 2 * l2_loss(x) == sum(x ** 2), so this is FLAGS.l2 * ||w||^2.
                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))
                # NOTE(review): if `model.losses` is empty, `sum` returns the
                # int 0 and `kl.dtype` below raises. This assumes the model
                # registers at least one KL loss.
                kl = sum(model.losses) / train_dataset_size
                # Linear KL annealing: the scale grows with the optimizer step
                # count and is capped at 1 after `kl_annealing_steps` steps.
                kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
                kl_scale /= FLAGS.kl_annealing_steps
                kl_scale = tf.minimum(1., kl_scale)
                kl_loss = kl_scale * kl

                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                loss = negative_log_likelihood + l2_loss + kl_loss
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)

            # Separate learning rate implementation.
            grad_list = []
            if FLAGS.fast_weight_lr_multiplier != 1.0:
                grads_and_vars = list(zip(grads, model.trainable_variables))
                for vec, var in grads_and_vars:
                    # Apply different learning rate on the fast weight approximate
                    # posterior/prior parameters. This excludes BN and slow weights,
                    # but pay caution to the naming scheme.
                    # NOTE(review): names containing 'bias' but not 'kernel' also
                    # receive the multiplier; confirm this is intended.
                    if ('batch_norm' not in var.name
                            and 'kernel' not in var.name):
                        grad_list.append(
                            (vec * FLAGS.fast_weight_lr_multiplier, var))
                    else:
                        grad_list.append((vec, var))
                optimizer.apply_gradients(grad_list)
            else:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            metrics['train/ece'].update_state(labels, probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)
            if FLAGS.version2 and FLAGS.ensemble_size > 1:
                for k, v in diversity_results.items():
                    training_diversity['train/' + k].update_state(v)
Example #24
0
 def _mode(self):
     """Returns the most likely category, one-hot encoded.

     Takes the argmax of `self.logits` along axis `self._batch_rank`, the
     first axis after the batch dimensions. It one-hot encodes that index
     with depth `self.event_size` and dtype `self.dtype`. The result is given
     the static shape of `self.logits`.
     """
     most_likely = tf.one_hot(
         tf.argmax(input=self.logits, axis=self._batch_rank),
         self.event_size,
         dtype=self.dtype)
     tensorshape_util.set_shape(most_likely, self.logits.shape)
     return most_likely
Example #25
0
def softquantiles(x,
                  quantiles,
                  quantile_width=None,
                  axis=-1,
                  may_squeeze=True,
                  **kwargs):
    """Computes soft quantiles via optimal transport.

    This operator takes advantage of the fact that an exhaustive softsort is not
    required to recover a single quantile. Instead, one can transport all
    input values in x onto only 3 weighted values. Target weights are adjusted so
    that those values in x that are transported to the middle value in the target
    vector y correspond to those concentrating around the quantile of interest.

    This idea generalizes to more quantiles, interleaving small weights on the
    quantile indices and bigger weights in between, corresponding to the gap from
    one desired quantile to the next one.

    All transport weights are built in the dtype of `x`, so non-float32 inputs
    (e.g. float64) are supported.

    Args:
      x: floating-point Tensor of any shape.
      quantiles: list<float> the quantiles to be returned. It can also be a
        single float. Values outside the open interval (0, 1) are silently
        dropped.
      quantile_width: (float) mass given to the bucket supposed to attract points
        whose value concentrate around the desired quantile value. Bigger width
        means that we allow the soft quantile to be a mixture of more points
        further away from the quantile. If None, the width is set to the
        minimum of 1/n (n being the size along `axis`) and the smallest gap
        between consecutive quantiles (including 0 and 1).
      axis: (int) the axis along which to compute the quantile.
      may_squeeze: (bool) should we squeeze the output tensor in case of a single
        quantile.
      **kwargs: extra parameters forwarded to `softsort` (see SoftQuantilizer
        for possible values).

    Returns:
      A Tensor<float> similar to the input tensor, but the axis dimension is
      replaced by the number of quantiles specified in the quantiles list.
      Hence, if only a quantile is requested (quantiles is a float) only one
      value in that axis is returned (and the axis is squeezed when
      `may_squeeze` is True). When several quantiles are requested, the tensor
      will have that many values in that axis.

    Raises:
      tf.errors.InvalidArgumentError: when the quantiles and quantile width are
        not correct, namely quantiles are either not in sorted order or the
        quantile_width is too large (some filler weight becomes negative).
    """
    if isinstance(quantiles, float):
        quantiles = [quantiles]
    # Every weight below is mixed with values in `x.dtype`. A hard-coded
    # float32 here made tf.concat / addition fail for non-float32 inputs.
    dtype = tf.as_dtype(x.dtype)
    quantiles = tf.constant(quantiles, dtype)

    # Preprocesses submitted quantiles to check that they satisfy elementary
    # constraints.
    valid_quantiles = tf.boolean_mask(
        quantiles, tf.logical_and(quantiles > 0.0, quantiles < 1.0))
    num_quantiles = tf.shape(valid_quantiles)[0]

    # Includes values on both ends of [0,1].
    extended_quantiles = tf.concat(
        [tf.zeros([1], dtype=dtype), valid_quantiles,
         tf.ones([1], dtype=dtype)],
        axis=0)

    # Builds filler_weights in between the target quantiles.
    filler_weights = extended_quantiles[1:] - extended_quantiles[:-1]
    if quantile_width is None:
        quantile_width = tf.reduce_min(
            tf.concat([
                filler_weights,
                [1.0 / tf.cast(tf.shape(x)[axis], dtype=dtype)]
            ],
                      axis=0))

    # Takes into account quantile_width in the definition of weights: interior
    # fillers lose a full width, the two end fillers only half a width.
    shift = -tf.ones(tf.shape(filler_weights), dtype=dtype)
    shift = shift + 0.5 * (
        tf.one_hot(0, num_quantiles + 1, dtype=dtype) +
        tf.one_hot(num_quantiles, num_quantiles + 1, dtype=dtype))
    filler_weights = filler_weights + quantile_width * shift

    # Negative fillers mean unsorted quantiles or a too-large quantile_width.
    assert_op = tf.Assert(tf.reduce_all(filler_weights >= 0.0),
                          [filler_weights])
    with tf.control_dependencies([assert_op]):
        # Adds one more value to have tensors of the same shape to interleave them.
        quantile_weights = tf.ones(num_quantiles + 1,
                                   dtype=dtype) * quantile_width

        # Interleaves the filler_weights with the quantile weights.
        weights = tf.reshape(
            tf.stack([filler_weights, quantile_weights], axis=1), (-1, ))[:-1]

        # Sends only the positive weights to the softsort operator.
        positive_weights = tf.boolean_mask(weights, weights > 0.0)
        all_quantiles = softsort(x,
                                 direction='ASCENDING',
                                 axis=axis,
                                 target_weights=positive_weights,
                                 **kwargs)

        # Recovers the indices corresponding to the desired quantiles: the
        # quantile buckets sit at odd positions of `weights`, and the cumulative
        # count of positive weights maps them to positions in `all_quantiles`.
        odds = tf.math.floormod(tf.range(weights.shape[0], dtype=tf.float32),
                                2)
        positives = tf.cast(weights > 0.0, tf.float32)
        indices = tf.cast(tf.math.cumsum(positives) * odds, dtype=tf.int32)
        indices = tf.boolean_mask(indices, indices > 0) - 1
        result = tf.gather(all_quantiles, indices, axis=axis)

        # In the specific case where we want a single quantile, squeezes the
        # quantile dimension.
        can_squeeze = tf.equal(tf.shape(result)[axis], 1)
        if tf.math.logical_and(can_squeeze, may_squeeze):
            result = tf.squeeze(result, axis=axis)
        return result
Example #26
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples from the mixture.

        A categorical draw from `self.cat` selects a component for every
        (sample, batch) position. The components are then sampled and their
        draws are routed back to the positions that selected them.

        Two code paths exist:
          * `self._use_static_graph`: every component is sampled `n` times and
            the draws are combined with a one-hot mask built from the
            categorical sample.
          * otherwise: positions are partitioned by component with
            `tf.dynamic_partition`. Each component is sampled only as many
            times as it was selected, and the pieces are reassembled with
            `tf.dynamic_stitch`.

        The categorical draw uses `seeds[0]` and component `c` uses
        `seeds[c + 1]`. If a component raises a `TypeError` whose message
        matches the known seed-type errors, it is resampled with a seed from
        the stateful `SeedStream` instead, and a warning is emitted.

        Args:
          n: number of samples (converted to a Tensor; resolved to a Python
            int when statically known).
          seed: seed split by `samplers.split_seed` and also passed to
            `SeedStream`.

        Returns:
          A Tensor whose shape is the categorical sample shape (`[n, B]` per
          the shape comments below) followed by the event shape.

        Raises:
          TypeError: re-raised from a component when it does not match the
            known seed-type messages, or the error from constructing
            `SeedStream` when a fallback is needed but the stream could not
            be built.
        """
        # One stateless seed for the categorical draw plus one per component.
        seeds = samplers.split_seed(seed,
                                    n=self.num_components + 1,
                                    salt='Mixture')
        # Stateful fallback stream. Building it can fail for a Tensor seed; the
        # error is kept and only re-raised if a fallback is actually needed.
        try:
            seed_stream = SeedStream(seed, salt='Mixture')
        except TypeError as e:  # Can happen for Tensor seed.
            seed_stream = None
            seed_stream_err = e
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code
            # path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seeds[0])

            for c in range(self.num_components):
                try:
                    samples.append(self.components[c].sample(n,
                                                             seed=seeds[c +
                                                                        1]))
                    # Advance the stateful stream once per component even on
                    # success (presumably to keep fallback seeds aligned with
                    # the component index).
                    if seed_stream is not None:
                        seed_stream()
                except TypeError as e:
                    # Only fall back for the known seed-type errors.
                    if ('Expected int for argument' not in str(e)
                            and TENSOR_SEED_MSG_PREFIX not in str(e)):
                        raise
                    if seed_stream is None:
                        raise seed_stream_err
                    msg = (
                        'Falling back to stateful sampling for `components[{}]` {} of '
                        'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                        'This fallback may be removed after 20-Aug-2020. ({})')
                    warnings.warn(
                        msg.format(c, self.components[c].name,
                                   type(self.components[c]), str(e)))
                    samples.append(self.components[c].sample(
                        n, seed=seed_stream()))
            # Stack component draws on a new axis just before the event dims.
            stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
            x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
            npdt = dtype_util.as_numpy_dtype(x.dtype)
            mask = tf.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=npdt(1),
                off_value=npdt(0))  # [n, B, k]
            mask = distribution_util.pad_mixture_dimensions(
                mask, self, self._cat,
                tensorshape_util.rank(
                    self._static_event_shape))  # [n, B, k, [1]*e]
            # The mask keeps only the selected component's draw per position.
            return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        n = tf.convert_to_tensor(n, name='n')
        static_n = tf.get_static_value(n)
        n = int(static_n) if static_n is not None else n
        cat_samples = self.cat.sample(n, seed=seeds[0])

        # Prefer static Python shapes/sizes when available, else dynamic ones.
        static_samples_shape = cat_samples.shape
        if tensorshape_util.is_fully_defined(static_samples_shape):
            samples_shape = tensorshape_util.as_list(static_samples_shape)
            samples_size = tensorshape_util.num_elements(static_samples_shape)
        else:
            samples_shape = tf.shape(cat_samples)
            samples_size = tf.size(cat_samples)
        static_batch_shape = self.batch_shape
        if tensorshape_util.is_fully_defined(static_batch_shape):
            batch_shape = tensorshape_util.as_list(static_batch_shape)
            batch_size = tensorshape_util.num_elements(static_batch_shape)
        else:
            batch_shape = tf.shape(cat_samples)[1:]
            batch_size = tf.reduce_prod(batch_shape)
        static_event_shape = self.event_shape
        if tensorshape_util.is_fully_defined(static_event_shape):
            event_shape = np.array(
                tensorshape_util.as_list(static_event_shape), dtype=np.int32)
        else:
            # Filled in from the first component's sample inside the loop below.
            event_shape = None

        # Get indices into the raw cat sampling tensor. We will
        # need these to stitch sample values back out after sampling
        # within the component partitions.
        samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                         samples_shape)

        # Partition the raw indices so that we can use
        # dynamic_stitch later to reconstruct the samples from the
        # known partitions.
        partitioned_samples_indices = tf.dynamic_partition(
            data=samples_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)

        # Copy the batch indices n times, as we will need to know
        # these to pull out the appropriate rows within the
        # component partitions.
        batch_raw_indices = tf.reshape(tf.tile(tf.range(0, batch_size), [n]),
                                       samples_shape)

        # Explanation of the dynamic partitioning below (n=3, batch_size=2):
        #   batch indices are, flattened, [0, 1, 0, 1, 0, 1]
        # Suppose partitions (component choices) are:
        #     [1 1 0 0 1 1]
        # After partitioning, batch indices are cut as:
        #     part 0: [batch_indices[x] for x in 2, 3]
        #     part 1: [batch_indices[x] for x in 0, 1, 4, 5]
        # i.e.
        #     [0 1] and [0 1 0 1]
        # Now we sample n=2 from part 0 and n=4 from part 1.
        # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
        # and for part 1 we want samples from batch entries 0, 1, 0, 1
        #   (samples 0, 1, 2, 3).
        partitioned_batch_indices = tf.dynamic_partition(
            data=batch_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)
        samples_class = [None for _ in range(self.num_components)]

        for c in range(self.num_components):
            # Number of positions that selected component c.
            n_class = tf.size(partitioned_samples_indices[c])
            try:
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=seeds[c + 1])
                # Advance the stateful stream once per component even on
                # success (presumably to keep fallback seeds aligned with the
                # component index).
                if seed_stream is not None:
                    seed_stream()
            except TypeError as e:
                # Only fall back for the known seed-type errors.
                if ('Expected int for argument' not in str(e)
                        and TENSOR_SEED_MSG_PREFIX not in str(e)):
                    raise
                if seed_stream is None:
                    raise seed_stream_err
                msg = (
                    'Falling back to stateful sampling for `components[{}]` {} of '
                    'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                    'This fallback may be removed after 20-Aug-2020. ({})')
                warnings.warn(
                    msg.format(c, self.components[c].name,
                               type(self.components[c]), str(e)))
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=seed_stream())

            if event_shape is None:
                # Event dims follow the sample dim and the batch dims.
                batch_ndims = prefer_static.rank_from_shape(batch_shape)
                event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

            # Pull out the correct batch entries from each index.
            # To do this, we may have to flatten the batch shape.

            # For sample s, batch element b of component c, we get the
            # partitioned batch indices from
            # partitioned_batch_indices[c]; and shift each element by
            # the sample index. The final lookup can be thought of as
            # a matrix gather along locations (s, b) in
            # samples_class_c where the n_class rows correspond to
            # samples within this component and the batch_size columns
            # correspond to batch elements within the component.
            #
            # Thus the lookup index is
            #   lookup[c, i] = batch_size * s[i] + b[c, i]
            # for i = 0 ... n_class[c] - 1, with s[i] = i.
            lookup_partitioned_batch_indices = (
                batch_size * tf.range(n_class) + partitioned_batch_indices[c])
            samples_class_c = tf.reshape(
                samples_class_c,
                tf.concat([[n_class * batch_size], event_shape], 0))
            samples_class_c = tf.gather(samples_class_c,
                                        lookup_partitioned_batch_indices,
                                        name='samples_class_c_gather')
            samples_class[c] = samples_class_c

        # Stitch back together the samples across the components.
        lhs_flat_ret = tf.dynamic_stitch(indices=partitioned_samples_indices,
                                         data=samples_class)
        # Reshape back to proper sample, batch, and event shape.
        ret = tf.reshape(lhs_flat_ret,
                         tf.concat([samples_shape, event_shape], 0))
        tensorshape_util.set_shape(
            ret,
            tensorshape_util.concatenate(static_samples_shape,
                                         self.event_shape))
        return ret
Example #27
0
def reduce_audio_in_batch(tensor, hparams=None, is_training=True):
    """Mixes several instrument notes from one batch into a single audio clip.

    For each of the first `hparams.timbre_training_max_instruments` entries,
    the note audio is delayed by a random offset in
    [0, hparams.timbre_max_start_offset). When `hparams.timbre_max_len` is set,
    the clip is truncated to that length. The clips are zero-padded to a common
    length and averaged. The mix is then padded to a fixed length and encoded
    as WAV data.

    Args:
      tensor: dict of features. Reads 'pitch', 'audio' (a SparseTensor that is
        densified) and 'instrument_family', each indexed per instrument along
        its first axis.
      hparams: hyperparameters. Required despite the `None` default.
      is_training: unused; kept for interface compatibility.

    Returns:
      dict with 'audio' (WAV bytes as a tf.string), 'note_croppings' (int32,
      one NoteCropping(pitch, start_idx, end_idx) per instrument) and
      'instrument_families' (int32 one-hot rows).

    Raises:
      ValueError: if `hparams` is None.
    """
    if hparams is None:
        raise ValueError('reduce_audio_in_batch requires `hparams`.')
    instrument_count = hparams.timbre_training_max_instruments
    max_len = hparams.timbre_max_len
    # Densify once rather than twice per instrument inside the loop.
    dense_audio = tf.sparse.to_dense(tensor['audio'])

    note_cropping_list = []
    instrument_family_list = []
    samples_list = []
    max_length = 0
    for i in range(instrument_count):
        pitch = tensor['pitch'][i]
        note_audio = dense_audio[i]
        # Move the audio so there are different attack times.
        start_idx = tf.random.uniform((),
                                      minval=0,
                                      maxval=hparams.timbre_max_start_offset,
                                      dtype='int64')
        samples = K.concatenate([tf.zeros(start_idx), note_audio])

        end_idx = start_idx + tf.py_function(_get_approx_note_length,
                                             [note_audio], tf.int64)
        if max_len:
            # Truncate whenever the shifted clip is too long, not only when the
            # note end is past the limit. Otherwise the final tf.pad below
            # would receive a negative padding amount and fail.
            if len(samples) > max_len:
                samples = tf.slice(samples, begin=[0], size=[max_len])
            if end_idx > max_len:
                end_idx = max_len
        max_length = max(max_length, len(samples))

        samples_list.append(samples)

        instrument_family = tensor['instrument_family'][i]
        note_cropping_list.append(
            timbre_dataset_util.NoteCropping(pitch=pitch,
                                             start_idx=start_idx,
                                             end_idx=end_idx))
        instrument_family_list.append(
            tf.one_hot(tf.cast(instrument_family, tf.int32),
                       hparams.timbre_num_classes))

    # Pad the end of the shorter audio clips.
    samples_list = [
        tf.pad(clip, [[0, max_length - len(clip)]]) for clip in samples_list
    ]

    # Average of the per-instrument clips.
    combined_samples = (
        tf.reduce_sum(tf.convert_to_tensor(samples_list), axis=0) /
        instrument_count)

    # Ensure all audios in batches are the same length.
    if max_len:
        pad_length = max_len
    else:
        pad_length = hparams.timbre_max_start_offset + 5 * hparams.sample_rate
    combined_samples = tf.pad(
        combined_samples, [[0, pad_length - tf.shape(combined_samples)[0]]])
    note_croppings = tf.convert_to_tensor(note_cropping_list, dtype=tf.int32)
    instrument_families = tf.convert_to_tensor(instrument_family_list,
                                               dtype=tf.int32)

    wav_data = tf.py_function(
        lambda x: audio_io.samples_to_wav_data(
            x.numpy(), sample_rate=hparams.sample_rate), [combined_samples],
        tf.string)

    return dict(
        audio=wav_data,
        note_croppings=note_croppings,
        instrument_families=instrument_families,
    )
def interpolate(x_values,
                spline_data,
                optimize_for_tpu=False,
                dtype=None,
                name=None):
    """Interpolates spline values for the given `x_values` and the `spline_data`.

    Constant extrapolation is performed for the values outside the domain
    `spline_data.x_data`. This means that for `x > max(spline_data.x_data)`,
    `interpolate(x, spline_data) = spline_data.y_data[-1]`
    and for `x < min(spline_data.x_data)`,
    `interpolate(x, spline_data) = spline_data.y_data[0]`.

    For the interpolation formula refer to p.548 of [1].

    #### References:
    [1]: R. Sedgewick, Algorithms in C, 1990, p. 545-550.
      Link: http://index-of.co.uk/Algorithms/Algorithms%20in%20C.pdf

    Args:
      x_values: A real `Tensor` of shape `batch_shape + [num_points]`.
      spline_data: An instance of `SplineParameters`. `spline_data.x_data`
        should have the same batch shape as `x_values` (or one that broadcasts
        against it).
      optimize_for_tpu: A Python bool. If `True`, the algorithm uses one-hot
        encoding to lookup indices of `x_values` in `spline_data.x_data`. This
        significantly improves performance of the algorithm on a TPU device but
        may slow down performance on the CPU.
        Default value: `False`.
      dtype: Optional dtype for `x_values`.
        Default value: `None` which maps to the default dtype inferred by
        TensorFlow.
      name: Python `str` name prefixed to ops created by this function.
        Default value: `None` which is mapped to the default name
        `cubic_spline_interpolate`.

    Returns:
        A `Tensor` with the `dtype` of `x_values` and its shape after any batch
        broadcasting. Represents the interpolated values.

    Raises:
      ValueError:
        If the batch shape of `x_values` differs from the batch shape of
        `spline_data.x_data` and the two cannot be broadcast together.
    """
    name = name or "cubic_spline_interpolate"
    with tf.name_scope(name):
        x_values = tf.convert_to_tensor(x_values, dtype=dtype, name="x_values")
        dtype = x_values.dtype
        # Unpack the spline data
        x_data = spline_data.x_data
        y_data = spline_data.y_data
        spline_coeffs = spline_data.spline_coeffs
        rank = max(x_data.shape.rank, x_values.shape.rank)
        x_data = _expand_to_rank(x_data, rank)
        y_data = _expand_to_rank(y_data, rank)
        x_values = _expand_to_rank(x_values, rank)
        spline_coeffs = _expand_to_rank(spline_coeffs, rank)
        # Try broadcast batch_shapes: first x_values to the spline batch shape,
        # then the spline data to the x_values batch shape.
        if x_values.shape.as_list()[:-1] != x_data.shape.as_list()[:-1]:
            try:
                x_values = _broadcast_batch_shape(x_values, x_data.shape[:-1])
            except (tf.errors.InvalidArgumentError, ValueError):
                try:
                    x_data = _broadcast_batch_shape(x_data,
                                                    x_values.shape[:-1])
                    y_data = _broadcast_batch_shape(y_data,
                                                    x_values.shape[:-1])
                    spline_coeffs = _broadcast_batch_shape(
                        spline_coeffs, x_values.shape[:-1])
                except (tf.errors.InvalidArgumentError, ValueError):
                    msg = ("Can not broadcast batch shapes {} and {}")
                    raise ValueError(
                        msg.format(x_values.shape.as_list()[:-1],
                                   x_data.shape.as_list()[:-1]))
        # Determine the splines to use.
        indices = tf.searchsorted(x_data, x_values, side="right") - 1
        # This selects all elements for the start of the spline interval.
        # Make sure indices lie in the permissible range
        indices_lower = tf.maximum(indices, 0)
        # This selects all elements for the end of the spline interval.
        # Make sure indices lie in the permissible range
        indices_upper = tf.minimum(indices + 1, x_data.shape.as_list()[-1] - 1)
        # Prepare indices for `tf.gather_nd` or `tf.one_hot`
        # TODO(b/156720909): Extract get_slice logic into a common utilities module
        # for cubic and linear interpolation
        if optimize_for_tpu:
            x_data_size = x_data.shape.as_list()[-1]
            lower_encoding = tf.one_hot(indices_lower,
                                        x_data_size,
                                        dtype=dtype)
            upper_encoding = tf.one_hot(indices_upper,
                                        x_data_size,
                                        dtype=dtype)
        else:
            index_matrix = _prepare_indices(indices)
            lower_encoding = tf.concat(
                [index_matrix, tf.expand_dims(indices_lower, -1)], -1)
            upper_encoding = tf.concat(
                [index_matrix, tf.expand_dims(indices_upper, -1)], -1)

        # Calculate dx and dy.
        # Simplified logic:
        # dx = x_data[indices + 1] - x_data[indices]
        # dy = y_data[indices + 1] - y_data[indices]
        # indices is a tensor with different values per row/spline
        # Hence use a selection matrix with gather_nd
        def get_slice(x, encoding):
            if optimize_for_tpu:
                return tf.math.reduce_sum(tf.expand_dims(x, axis=-2) *
                                          encoding,
                                          axis=-1)
            else:
                return tf.gather_nd(x, encoding)

        x0 = get_slice(x_data, lower_encoding)
        x1 = get_slice(x_data, upper_encoding)
        dx = x1 - x0

        y0 = get_slice(y_data, lower_encoding)
        y1 = get_slice(y_data, upper_encoding)
        dy = y1 - y0

        spline_coeffs0 = get_slice(spline_coeffs, lower_encoding)
        spline_coeffs1 = get_slice(spline_coeffs, upper_encoding)

        # dx == 0 when lower and upper indices coincide (points outside the
        # domain, or duplicate knots). A plain division gives inf/NaN there.
        # tf.where masks the forward value, but the backward pass still
        # computes 0/0 = NaN gradients. divide_no_nan returns 0 (with a zero
        # gradient) instead.
        t = tf.math.divide_no_nan(x_values - x0, dx)
        t = tf.where(dx > 0, t, tf.zeros_like(t))
        df = ((t + 1.0) * spline_coeffs1 * 2.0) - (
            (t - 2.0) * spline_coeffs0 * 2.0)
        df1 = df * t * (t - 1) / 6.0
        result = y0 + (t * dy) + (dx * dx * df1)
        # Use constant extrapolation outside the domain
        upper_bound = tf.expand_dims(tf.reduce_max(x_data, -1),
                                     -1) + tf.zeros_like(result)
        lower_bound = tf.expand_dims(tf.reduce_min(x_data, -1),
                                     -1) + tf.zeros_like(result)
        result = tf.where(
            tf.logical_and(x_values <= upper_bound, x_values >= lower_bound),
            result, tf.where(x_values > upper_bound, y0, y1))
        return result
Example #29
0
 def convert_to_one_hot(self, samples):
     """Snaps `samples` to one-hot vectors at their argmax on the last axis.

     Returns:
       A one-hot Tensor with depth `self.distribution.event_size` and dtype
       `self._output_dtype`.
     """
     winner = tf.argmax(samples, axis=-1)
     depth = self.distribution.event_size
     return tf.one_hot(winner, depth, dtype=self._output_dtype)
 def map_fn(image, label):
     """Preprocesses one (image, label) pair for fine-tuning.

     The image goes through `preprocess_fn_finetune`. The label is one-hot
     encoded with depth `num_classes` (both names come from the enclosing
     scope).
     """
     processed_image = preprocess_fn_finetune(image)
     return processed_image, tf.one_hot(label, num_classes)