def call(self, x, training=False):
    """Quantize ``x`` against the centroid codebook, one lookup per head.

    Each input vector (last axis of ``x``, of size ``self.depth``) is split into
    ``self.num_heads`` equal segments and every segment is replaced by its
    nearest centroid. Centroids are ``self.sums / self.counts``.

    Args:
      x: Input tensor whose last dimension is ``self.depth``.
      training: If True, centroids with ``counts * k`` not above
        ``restart_threshold`` times the number of segments in the batch are
        re-initialized from batch segments (uniform noise as fallback), and
        ``self.counts`` / ``self.sums`` are updated as exponential moving
        averages with decay ``self.gamma``.

    Returns:
      A tuple ``(z, c)``. ``z`` has the shape of ``x`` and holds the quantized
      vectors with a straight-through gradient to ``x``. ``c`` holds the chosen
      centroid index per head, shaped ``x.shape[:-1] + [self.num_heads]``.
    """
    x_flat = tf.reshape(x, shape=(-1, self.depth))
    # Split each input vector into one segment per head, then stack the
    # segments along the row axis so all heads share a single lookup.
    x_flat_split = tf.split(x_flat, self.num_heads, axis=1)
    x_flat = tf.concat(x_flat_split, axis=0)

    if training:
        # Use the dynamic row count: the static x_flat.shape[0] is None when the
        # layer is traced with an unknown batch size, which would break the
        # arithmetic below.
        n = tf.shape(x_flat)[0]
        # Figure out which centroids we want to keep, and which we want to
        # restart.
        keep = (self.counts * self.k
                > self.restart_threshold * tf.cast(n, self.counts.dtype))
        restart = tf.math.logical_not(keep)
        # Replace centroids to restart with elements from the batch, using
        # samples from a uniform distribution as a fallback in case we need to
        # restart more centroids than we have elements in the batch.
        restart_idx = tf.squeeze(tf.where(restart), -1)
        n_replace = tf.minimum(tf.shape(restart_idx)[0], n)
        e_restart = tf.tensor_scatter_nd_update(
            tf.random.uniform([self.k, self.depth // self.num_heads]),
            tf.expand_dims(restart_idx[:n_replace], 1),
            tf.random.shuffle(x_flat)[:n_replace],
        )
        # Compute the values of the centroids we want to keep by dividing the
        # summed vectors by the corresponding counts.
        e = tf.where(
            tf.expand_dims(keep, 1),
            tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1)),
            e_restart,
        )
    else:
        # If not training, just use the centroids as is with no restarts.
        e = tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1))

    # Squared Euclidean distance expanded as ||x||^2 - 2 x.e + ||e||^2.
    distances = (
        tf.expand_dims(tf.reduce_sum(x_flat**2, axis=1), 1)
        - 2 * tf.matmul(x_flat, tf.transpose(e))
        + tf.expand_dims(tf.reduce_sum(e**2, axis=1), 0)
    )
    # Find nearest cluster center for each input segment.
    c = tf.argmin(distances, axis=1)

    # Quantize with a straight-through estimator: undo the per-head stacking,
    # restore x's shape, and route gradients straight to x.
    z = tf.nn.embedding_lookup(e, c)
    z_split = tf.split(z, self.num_heads, axis=0)
    z = tf.concat(z_split, axis=1)
    z = tf.reshape(z, tf.shape(x))
    z = x + tf.stop_gradient(z - x)

    if training:
        # Compute cluster counts and vector sums over the batch.
        oh = tf.one_hot(indices=c, depth=self.k)
        counts = tf.reduce_sum(oh, axis=0)
        sums = tf.matmul(oh, x_flat, transpose_a=True)
        # Apply exponential moving average to cluster counts and vector sums.
        self.counts.assign_sub((1 - self.gamma) * (self.counts - counts))
        self.sums.assign_sub((1 - self.gamma) * (self.sums - sums))

    # Regroup per-head indices into a trailing head axis and restore x's
    # leading dimensions.
    c_split = tf.split(c, self.num_heads, axis=0)
    c = tf.stack(c_split, axis=1)
    c = tf.reshape(c, tf.concat([tf.shape(x)[:-1], [self.num_heads]], axis=0))
    return z, c
def _mode(self):
    """Return the mode as a one-hot tensor over the last (event) axis.

    The arg-max of the logits is one-hot encoded with depth
    ``self._event_size(logits)`` and dtype ``self.dtype``; the result is given
    the static shape of the logits.
    """
    logits = self._logits_parameter_no_checks()
    best_index = tf.argmax(logits, axis=-1)
    depth = self._event_size(logits)
    mode_one_hot = tf.one_hot(best_index, depth, dtype=self.dtype)
    # Pin the static shape to match the logits.
    tensorshape_util.set_shape(mode_one_hot, logits.shape)
    return mode_one_hot
def helper_test_keras_v2_gradienttape(script_mode: bool = False, json_file_contents="{}"):
    """Test the default ZCC behavior of saving losses and metrics in eager and non-eager modes.

    Trains a small MNIST classifier for two epochs with a hand-written
    ``tf.GradientTape`` loop inside a ``SagemakerSimulator`` and then checks
    that the smdebug trial in ``sim.out_dir`` has steps, tensors, and at least
    one tensor in the "losses" collection.

    Args:
      script_mode: If True, a ``smd.KerasHook`` is created explicitly and the
        tape is wrapped with it; otherwise the hook is expected to be created
        by zero-code-change (ZCC) support.
      json_file_contents: smdebug JSON configuration for the simulator. In
        script mode, the default ``"{}"`` builds the hook directly (with
        TensorBoard export); any other value builds it from the JSON file.
    """
    smd.del_hook()
    tf.keras.backend.clear_session()

    with SagemakerSimulator(json_file_contents=json_file_contents) as sim:
        model = tf.keras.models.Sequential(
            [
                tf.keras.layers.Flatten(input_shape=(28, 28, 1)),  # WA for TF issue #36279
                tf.keras.layers.Dense(128, activation="relu"),
                tf.keras.layers.Dropout(0.2),
                # Emit raw logits: the loss below is built with from_logits=True,
                # so a softmax here would be applied twice.
                tf.keras.layers.Dense(10),
            ]
        )
        (x_train, y_train), _ = get_keras_data()
        dataset = tf.data.Dataset.from_tensor_slices(
            (tf.cast(x_train[..., tf.newaxis] / 255, tf.float32), tf.cast(y_train, tf.int64))
        )
        dataset = dataset.shuffle(1000).batch(64)

        opt = tf.keras.optimizers.RMSprop()
        cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

        n_epochs = 2
        if script_mode:
            if json_file_contents == "{}":
                hook = smd.KerasHook(out_dir=sim.out_dir, export_tensorboard=True)
            else:
                hook = smd.KerasHook.create_from_json_file()

            for epoch in range(n_epochs):
                print("Epoch %d/%d" % (epoch + 1, n_epochs))
                for data, labels in dataset:
                    dataset_labels = labels
                    labels = tf.one_hot(labels, depth=10)
                    with hook.wrap_tape(tf.GradientTape()) as tape:
                        logits = model(data, training=True)  # (batch, 10)
                        loss_value = cce(labels, logits)
                    grads = tape.gradient(loss_value, model.variables)
                    opt.apply_gradients(zip(grads, model.variables))
                    acc = train_acc_metric(dataset_labels, logits)
                    hook.record_tensor_value(tensor_name="accuracy", tensor_value=acc)
                log = "Epoch %d " % (epoch + 1)
                log += "Accuracy %.4f" % train_acc_metric.result()
                print(log)
                train_acc_metric.reset_states()

            hook = smd.get_hook()
            assert hook
            hook.close()
            # Check that hook created and tensors saved
            trial = smd.create_trial(path=sim.out_dir)
            assert len(trial.steps()) > 0, "Nothing saved at any step."
            assert len(trial.tensor_names()) > 0, "Tensors were not saved."
            assert len(trial.tensor_names(collection="losses")) > 0
        else:
            # ZCC support added from smdebug v0.8.0
            for epoch in range(n_epochs):
                print("Epoch %d/%d" % (epoch + 1, n_epochs))
                for data, labels in dataset:
                    dataset_labels = labels
                    labels = tf.one_hot(labels, depth=10)
                    with tf.GradientTape(persistent=True) as tape:
                        logits = model(data, training=True)  # (batch, 10)
                        loss_value = cce(labels, logits)
                    grads = tape.gradient(loss_value, model.variables)
                    opt.apply_gradients(zip(grads, model.variables))
                    acc = train_acc_metric(dataset_labels, logits)
                log = "Epoch %d " % (epoch + 1)
                log += "Accuracy %.4f" % train_acc_metric.result()
                print(log)
                train_acc_metric.reset_states()

            hook = smd.get_hook()
            if not is_tf_2_2():
                assert not hook  # only supported on TF 2.2 and greater
                return
            assert hook
            hook.close()
            # Check that hook created and tensors saved
            trial = smd.create_trial(path=sim.out_dir)
            assert len(trial.steps()) > 0, "Nothing saved at any step."
            assert len(trial.tensor_names()) > 0, "Tensors were not saved."
            assert len(trial.tensor_names(collection="losses")) > 0
def _sample_n(self, n, seed=None):
    """Sample ``n`` observation sequences from the hidden Markov model.

    Hidden-state trajectories are drawn first: the initial state, then
    ``num_steps - 1`` transitions via ``tf.scan``. Observations are then drawn
    for every possible state and selected with a one-hot mask of the hidden
    states.

    Args:
      n: Number of samples to draw.
      seed: Seed for ``seed_stream.SeedStream`` (salt ``"HiddenMarkovModel"``).
        Every underlying sampler, including the observation distribution, gets
        its own seed from the stream.

    Returns:
      Observations shaped ``[n] + batch_shape + [num_steps] + event_shape``.
    """
    with tf.control_dependencies(self._runtime_assertions):
        seed = seed_stream.SeedStream(seed, salt="HiddenMarkovModel")

        num_states = self._num_states

        batch_shape = self.batch_shape_tensor()
        batch_size = tf.reduce_prod(input_tensor=batch_shape)

        # The batch sizes of the underlying initial distributions and
        # transition distributions might not match the batch size of
        # the HMM distribution.
        # As a result we need to ask for more samples from the
        # underlying distributions and then reshape the results into
        # the correct batch size for the HMM.
        init_repeat = (
            tf.reduce_prod(input_tensor=self.batch_shape_tensor()) //
            tf.reduce_prod(
                input_tensor=self._initial_distribution.batch_shape_tensor()))
        init_state = self._initial_distribution.sample(
            n * init_repeat, seed=seed())
        init_state = tf.reshape(init_state, [n, batch_size])
        # init_state :: n batch_size

        transition_repeat = (
            tf.reduce_prod(input_tensor=self.batch_shape_tensor()) //
            tf.reduce_prod(
                input_tensor=self._transition_distribution.batch_shape_tensor()[:-1]))

        def generate_step(state, _):
            """Take a single step in Markov chain."""
            gen = self._transition_distribution.sample(
                n * transition_repeat, seed=seed())
            # gen :: (n * transition_repeat) transition_batch

            new_states = tf.reshape(gen, [n, batch_size, num_states])
            # new_states :: n batch_size num_states

            old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)
            # old_states :: n batch_size num_states

            # For each chain, keep the next state sampled from the transition
            # row of its current state.
            return tf.reduce_sum(
                input_tensor=old_states_one_hot * new_states, axis=-1)

        if self._num_steps > 1:
            dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
            hidden_states = tf.scan(
                generate_step, dummy_index, initializer=init_state)

            # TODO(b/115618503): add/use prepend_initializer to tf.scan
            hidden_states = tf.concat([[init_state], hidden_states], axis=0)
        else:
            hidden_states = init_state[tf.newaxis, ...]

        # hidden_states :: num_steps n batch_size

        hidden_one_hot = tf.one_hot(
            hidden_states, num_states,
            dtype=self._observation_distribution.dtype)
        # hidden_one_hot :: num_steps n batch_size num_states

        # The observation distribution batch size might not match
        # the required batch size so as with the initial and
        # transition distributions we generate more samples and
        # reshape.
        observation_repeat = (
            batch_size //
            tf.reduce_prod(input_tensor=self._observation_distribution.
                           batch_shape_tensor()[:-1]))

        # Seed this draw from the stream as well; without it the observations
        # were not reproducible even when `seed` was given.
        possible_observations = self._observation_distribution.sample(
            [self._num_steps, observation_repeat * n], seed=seed())

        # Use the dynamic event shape so partially-unknown static shapes work.
        inner_shape = self._observation_distribution.event_shape_tensor()
        # possible_observations :: num_steps (observation_repeat * n)
        #                          observation_batch[:-1] num_states inner_shape

        possible_observations = tf.reshape(
            possible_observations,
            tf.concat([[self._num_steps, n], batch_shape, [num_states],
                       inner_shape], axis=0))
        # possible_observations :: steps n batch_size num_states inner_shape

        hidden_one_hot = tf.reshape(
            hidden_one_hot,
            tf.concat([[self._num_steps, n], batch_shape, [num_states],
                       tf.ones_like(inner_shape)], axis=0))
        # hidden_one_hot :: steps n batch_size num_states "inner_shape"

        # Keep only the observation sampled for the actual hidden state.
        observations = tf.reduce_sum(
            input_tensor=hidden_one_hot * possible_observations,
            axis=-1 - tf.size(input=inner_shape))
        # observations :: steps n batch_size inner_shape

        observations = distribution_util.move_dimension(
            observations, 0, 1 + tf.size(input=batch_shape))
        # returned :: n batch_shape steps inner_shape

        return observations
def helper_keras_gradtape(
    trial_dir,
    save_all=False,
    include_collections=None,
    reduction_config=None,
    save_config=None,
    hook=None,
    batch_size=64,
    persistent=False,
):
    """Train an MNIST classifier for one epoch with a hook-wrapped GradientTape.

    Args:
      trial_dir: Output directory for a newly created ``smd.KerasHook``
        (unused when ``hook`` is supplied).
      save_all: Forwarded to ``smd.KerasHook`` when a hook is created.
      include_collections: Forwarded to ``smd.KerasHook`` when a hook is
        created. If ``save_all`` is False, any of the hook's collections not
        listed here get ``SaveConfig(end_step=0)``.
      reduction_config: Forwarded to ``smd.KerasHook`` when a hook is created.
      save_config: Save config for a created hook; defaults to
        ``SaveConfig(save_interval=3)``.
      hook: An existing hook to use instead of creating one.
      batch_size: Training batch size.
      persistent: Use a persistent ``tf.GradientTape`` and call
        ``tape.gradient`` a second time to exercise that path.
    """
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), _ = mnist.load_data()
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(x_train[..., tf.newaxis] / 255, tf.float32), tf.cast(y_train, tf.int64))
    )
    dataset = dataset.shuffle(1000).batch(batch_size)

    model = tf.keras.models.Sequential(
        [
            # WA for TF issue https://github.com/tensorflow/tensorflow/issues/36279
            tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            # Emit raw logits: the loss below is built with from_logits=True,
            # so a softmax here would be applied twice.
            tf.keras.layers.Dense(10),
        ]
    )

    if hook is None:
        if save_config is None:
            save_config = SaveConfig(save_interval=3)

        hook = smd.KerasHook(
            trial_dir,
            save_config=save_config,
            save_all=save_all,
            include_collections=include_collections,
            reduction_config=reduction_config,
        )

        if not save_all and include_collections is not None:
            # Presumably end_step=0 stops saving for collections the caller
            # did not request — confirm against smdebug SaveConfig semantics.
            for cname in hook.include_collections:
                if cname not in include_collections:
                    hook.get_collection(cname).save_config = SaveConfig(end_step=0)

    opt = tf.keras.optimizers.Adam()
    hook.wrap_optimizer(opt)

    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

    n_epochs = 1
    for _ in range(n_epochs):
        for data, labels in dataset:
            dataset_labels = labels
            labels = tf.one_hot(labels, depth=10)
            with hook.wrap_tape(tf.GradientTape(persistent=persistent)) as tape:
                logits = model(data, training=True)  # (batch_size, 10)
                loss_value = cce(labels, logits)
            grads = tape.gradient(loss_value, model.variables)

            # By default, the resources held by a GradientTape are released as
            # soon as GradientTape.gradient() method is called. To compute
            # multiple gradients over the same computation, create a persistent
            # gradient tape. This allows multiple calls to the gradient() method
            # as resources are released when the tape object is garbage collected.
            if persistent:
                _ = tape.gradient(loss_value, model.variables)

            opt.apply_gradients(zip(grads, model.variables))
            acc = train_acc_metric(dataset_labels, logits)
            hook.record_tensor_value(tensor_name="accuracy", tensor_value=acc)
        train_acc_metric.reset_states()

    hook.close()
def _one_hot_encoding_label(wav, label):
    """Return ``wav`` unchanged together with ``label`` one-hot encoded.

    The one-hot depth comes from the name ``num_classes``, which is resolved
    outside this function.
    """
    encoded_label = tf.one_hot(label, num_classes)
    return wav, encoded_label
def train_step(self, dataset: dataset_lib.OffpolicyDataset,
               target_policy: tf_policy.TFPolicy,
               regularizer: float = 1e-6):
    """Performs single iteration of CoinDICE.

    The iteration has four steps:
    1. Build a saddle-point (Lagrangian) loss from the current ``nu``/``zeta``
       and ``nu2``/``zeta2`` estimates.
    2. Binary-search ``self._alpha`` and compute sample weights.
    3. Take a Newton step on ``self._nu`` and ``self._nu2``.
    4. Move ``self._zeta`` and ``self._zeta2`` toward their closed-form targets.

    NOTE(review): ``dataset`` and ``target_policy`` are not referenced in this
    body; all data comes from attributes such as ``self._a_vec``,
    ``self._weighted_rewards`` and ``self._nu_indices``.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.
      regularizer: A small constant to add to matrices before inverting them or
        to floats before taking square root.

    Returns:
      A list ``[avg_saddle_loss * self._algae_alpha_sign,
      avg_loss * self._algae_alpha_sign, divergence]``.
    """
    # First compute Lagrangian loss.
    saddle_bellman_residuals = (tf.matmul(self._a_vec, self._nu) -
                                self._weighted_rewards[:, None])
    saddle_bellman_residuals *= -1 * self._algae_alpha_sign
    saddle_zetas = tf.gather(self._zeta, self._nu_indices)
    saddle_initial_nu_values = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu, self._initial_nu_indices),
        axis=1)
    saddle_init_nu_loss = ((1 - self._gamma) * saddle_initial_nu_values *
                           self._algae_alpha_sign)

    # This second optimization switches the sign of algae_alpha.
    # We add these two together to get the final loss, and thus counteract
    # the bias introduced by algae_alpha.
    saddle_bellman_residuals2 = (tf.matmul(self._a_vec, self._nu2) -
                                 self._weighted_rewards[:, None])
    saddle_bellman_residuals2 *= 1 * self._algae_alpha_sign
    saddle_zetas2 = tf.gather(self._zeta2, self._nu_indices)
    saddle_initial_nu_values2 = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu2, self._initial_nu_indices),
        axis=1)
    saddle_init_nu_loss2 = ((1 - self._gamma) * saddle_initial_nu_values2 *
                            -1 * self._algae_alpha_sign)

    # Average of the two sign-flipped saddle objectives.
    saddle_loss = 0.5 * (
        saddle_init_nu_loss + saddle_bellman_residuals * saddle_zetas +
        -tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas) +
        -saddle_init_nu_loss2 + -saddle_bellman_residuals2 * saddle_zetas2 +
        tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas2))
    # Find optimal weights by doing binary search on alpha (lambda in the
    # paper).
    # Each of the two alpha entries starts in [-8, 32]. Where the divergence
    # exceeds self._two_sided_limit the lower bound moves up to mid; otherwise
    # the upper bound moves down. After 16 halvings, the midpoint is used.
    left = tf.constant([-8., -8.])
    right = tf.constant([32., 32.])
    for _ in range(16):
        mid = 0.5 * (left + right)
        self._alpha.assign(mid)
        weights, log_weights = self._get_weights(saddle_loss)

        divergence = self._compute_divergence(weights, log_weights)
        divergence_violation = divergence - self._two_sided_limit

        left = tf.where(divergence_violation > 0., mid, left)
        right = tf.where(divergence_violation > 0., right, mid)

    self._alpha.assign(0.5 * (left + right))
    weights, log_weights = self._get_weights(saddle_loss)

    # Now that we have weights, we reconstruct the Bellman residual matrices.
    # data_weights is a gradient-blocked copy of the weights.
    data_weights = tf.stop_gradient(weights)
    avg_saddle_loss = (tf.reduce_sum(data_weights * saddle_loss, axis=0) /
                       tf.reduce_sum(data_weights, axis=0))

    # Total weight per (state-action index, weight column), then gathered back
    # so each sample row carries the total of its own index.
    weighted_state_action_count = tf.reduce_sum(
        tf.one_hot(self._nu_indices, self._dimension)[:, :, None] *
        weights[:, None, :],
        axis=0)
    weighted_state_action_count = tf.gather(weighted_state_action_count,
                                            self._nu_indices)
    # my_td_mat[b, i, j] = sum_a onehot[a, i] * weights[a, b] * a_vec[a, j]
    #                      / weighted_state_action_count[a, b]
    my_td_mat = tf.einsum('ai, ab, ab, aj -> bij',
                          tf.one_hot(self._nu_indices, self._dimension),
                          1.0 / weighted_state_action_count, weights,
                          self._a_vec)
    # Weighted, count-normalized rewards accumulated per state-action index.
    my_bias = tf.reduce_sum(
        tf.transpose(weights)[:, :, None] *
        tf.one_hot(self._nu_indices, self._dimension)[None, :, :] *
        tf.reshape(self._weighted_rewards, [1, -1, 1]) * 1.0 /
        tf.transpose(weighted_state_action_count)[:, :, None],
        axis=1)

    # Solve for nu using primal form; i.e., E[(nu - B nu)^2] - (1-g) * E[nu0].
    # The tape is persistent because it is queried for two gradients here and
    # two Jacobians below.
    with tf.GradientTape(
        watch_accessed_variables=False, persistent=True) as tape:
        tape.watch([self._nu, self._nu2, self._alpha])
        bellman_residuals = tf.matmul(
            my_td_mat,
            tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
        bellman_residuals = tf.transpose(tf.squeeze(bellman_residuals, -1))
        bellman_residuals = tf.gather(bellman_residuals, self._nu_indices)
        initial_nu_values = tf.reduce_sum(  # Average over actions.
            self._initial_target_probs[:, :, None] *
            tf.gather(self._nu, self._initial_nu_indices),
            axis=1)

        bellman_residuals *= self._algae_alpha_sign

        init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                        self._algae_alpha_sign)

        nu_loss = (tf.math.square(bellman_residuals) / 2.0 +
                   tf.math.abs(self._algae_alpha) * init_nu_loss)

        loss = (data_weights * nu_loss /
                tf.reduce_sum(data_weights, axis=0, keepdims=True))

        # Same loss for the second (sign-flipped) estimate nu2.
        bellman_residuals2 = tf.matmul(
            my_td_mat,
            tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
        bellman_residuals2 = tf.transpose(tf.squeeze(bellman_residuals2, -1))
        bellman_residuals2 = tf.gather(bellman_residuals2, self._nu_indices)
        initial_nu_values2 = tf.reduce_sum(  # Average over actions.
            self._initial_target_probs[:, :, None] *
            tf.gather(self._nu2, self._initial_nu_indices),
            axis=1)

        bellman_residuals2 *= -1 * self._algae_alpha_sign

        init_nu_loss2 = ((1 - self._gamma) * initial_nu_values2 * -1 *
                         self._algae_alpha_sign)

        nu_loss2 = (tf.math.square(bellman_residuals2) / 2.0 +
                    tf.math.abs(self._algae_alpha) * init_nu_loss2)

        loss2 = (data_weights * nu_loss2 /
                 tf.reduce_sum(data_weights, axis=0, keepdims=True))

        # NOTE(review): divergence_violation computed here is not used later
        # in this function.
        divergence = self._compute_divergence(weights, log_weights)
        divergence_violation = divergence - self._two_sided_limit

        # Extra loss if for the 'terminal' state (index = -1).
        extra_loss = tf.reduce_sum(tf.math.square(self._nu[-1, :]))
        extra_loss2 = tf.reduce_sum(tf.math.square(self._nu2[-1, :]))

        # Gradients are taken inside the tape so the Jacobians below can
        # differentiate them again.
        nu_grad = tape.gradient(loss + extra_loss, [self._nu])[0]
        nu_grad2 = tape.gradient(loss2 + extra_loss2, [self._nu2])[0]

    avg_loss = tf.reduce_sum(
        0.5 * (loss - loss2) / tf.math.abs(self._algae_alpha), axis=0)
    # Keep, for each column i < self._num_limits, the Hessian block of that
    # column with respect to itself.
    nu_jacob = tape.jacobian(nu_grad, [self._nu])[0]
    nu_hess = tf.stack([nu_jacob[:, i, :, i] for i in range(self._num_limits)],
                       axis=0)

    nu_jacob2 = tape.jacobian(nu_grad2, [self._nu2])[0]
    nu_hess2 = tf.stack(
        [nu_jacob2[:, i, :, i] for i in range(self._num_limits)], axis=0)

    for idx, div in enumerate(divergence):
        tf.summary.scalar('divergence%d' % idx, div)

    # Perform Newton step on nu.
    # Solves (H + regularizer * I) d = -grad; the regularizer keeps the
    # system well-conditioned.
    nu_transformed = tf.transpose(
        tf.squeeze(
            tf.linalg.solve(nu_hess + regularizer * tf.eye(self._dimension),
                            tf.expand_dims(-tf.transpose(nu_grad), axis=-1))))
    self._nu = self._nu + self._nu_learning_rate * nu_transformed
    nu_transformed2 = tf.transpose(
        tf.squeeze(
            tf.linalg.solve(nu_hess2 + regularizer * tf.eye(self._dimension),
                            tf.expand_dims(-tf.transpose(nu_grad2), axis=-1))))
    self._nu2 = self._nu2 + self._nu_learning_rate * nu_transformed2

    # Perform step on zeta based on fact that zeta* = (nu* - bellman nu*)/a.
    # Uses the nu values just updated by the Newton step above.
    zetas = tf.matmul(my_td_mat,
                      tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
    zetas = tf.transpose(tf.squeeze(zetas, -1))
    zetas *= -self._algae_alpha_sign
    zetas /= tf.math.abs(self._algae_alpha)
    self._zeta = self._zeta + self._zeta_learning_rate * (zetas - self._zeta)

    zetas2 = tf.matmul(my_td_mat,
                       tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
    zetas2 = tf.transpose(tf.squeeze(zetas2, -1))
    zetas2 *= 1 * self._algae_alpha_sign
    zetas2 /= tf.math.abs(self._algae_alpha)
    self._zeta2 = (
        self._zeta2 + self._zeta_learning_rate * (zetas2 - self._zeta2))

    return [
        avg_saddle_loss * self._algae_alpha_sign,
        avg_loss * self._algae_alpha_sign, divergence
    ]
def _sample_n(self, n, seed=None):
    """Sample ``n`` observation sequences from the hidden Markov model.

    Hidden-state trajectories are drawn first: the initial state, then
    ``num_steps - 1`` transitions via ``tf.scan`` when ``num_steps > 1``.
    Observations are then sampled for every state and selected with a one-hot
    mask of the hidden states.

    Args:
      n: Number of samples to draw.
      seed: Seed for ``SeedStream`` (salt ``'HiddenMarkovModel'``). The
        initial, transition and observation draws each take a seed from the
        stream.

    Returns:
      Observations shaped ``[n] + batch_shape + [num_steps] + event_shape``.
    """
    strm = SeedStream(seed, salt='HiddenMarkovModel')

    # The transition distribution's last batch dimension indexes the current
    # state, so its size is the number of states.
    transition_batch_shape = self.transition_distribution.batch_shape_tensor()
    num_states = transition_batch_shape[-1]

    batch_shape = self.batch_shape_tensor()
    batch_size = tf.reduce_prod(batch_shape)

    # The batch sizes of the underlying initial distributions and
    # transition distributions might not match the batch size of
    # the HMM distribution.
    # As a result we need to ask for more samples from the
    # underlying distributions and then reshape the results into
    # the correct batch size for the HMM.
    init_repeat = (
        tf.reduce_prod(batch_shape) //
        tf.reduce_prod(self._initial_distribution.batch_shape_tensor()))
    init_state = self._initial_distribution.sample(n * init_repeat,
                                                   seed=strm())
    init_state = tf.reshape(init_state, [n, batch_size])
    # init_state :: n batch_size

    transition_repeat = (
        tf.reduce_prod(batch_shape) //
        tf.reduce_prod(transition_batch_shape[:-1]))

    init_shape = init_state.shape

    def generate_step(state, _):
        """Take a single step in Markov chain."""

        gen = self._transition_distribution.sample(n * transition_repeat,
                                                   seed=strm())
        # gen :: (n * transition_repeat) transition_batch

        new_states = tf.reshape(gen, [n, batch_size, num_states])

        # new_states :: n batch_size num_states

        old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)

        # old_states :: n batch_size num_states

        # For each chain, keep the next state sampled from the row of its
        # current state.
        result = tf.reduce_sum(old_states_one_hot * new_states, axis=-1)
        # We know that `generate_step` must preserve the shape of the
        # tensor of states of each state. This is because
        # the transition matrix must be square. But TensorFlow might
        # not know this so we explicitly tell it that the result has the
        # same shape.
        result.set_shape(init_shape)
        return result

    def _scan_multiple_steps():
        """Take multiple steps with tf.scan."""
        # Only the length of dummy_index matters: it drives num_steps - 1
        # scan iterations.
        dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
        if seed is not None:
            # Force parallel_iterations to 1 to ensure reproducibility
            # b/139210489
            hidden_states = tf.scan(generate_step, dummy_index,
                                    initializer=init_state,
                                    parallel_iterations=1)
        else:
            # Invoke default parallel_iterations behavior
            hidden_states = tf.scan(generate_step, dummy_index,
                                    initializer=init_state)

        # TODO(b/115618503): add/use prepend_initializer to tf.scan
        return tf.concat([[init_state], hidden_states], axis=0)

    # hidden_states :: num_steps n batch_size
    hidden_states = prefer_static.cond(
        self._num_steps > 1,
        _scan_multiple_steps,
        lambda: init_state[tf.newaxis, ...])

    hidden_one_hot = tf.one_hot(hidden_states, num_states,
                                dtype=self._observation_distribution.dtype)
    # hidden_one_hot :: num_steps n batch_size num_states

    # The observation distribution batch size might not match
    # the required batch size so as with the initial and
    # transition distributions we generate more samples and
    # reshape.
    observation_repeat = (
        batch_size // tf.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]))

    possible_observations = self._observation_distribution.sample(
        [self._num_steps, observation_repeat * n], seed=strm())

    inner_shape = self._observation_distribution.event_shape_tensor()

    # possible_observations :: num_steps (observation_repeat * n)
    #                          observation_batch[:-1] num_states inner_shape

    possible_observations = tf.reshape(
        possible_observations,
        tf.concat([[self._num_steps, n], batch_shape, [num_states],
                   inner_shape], axis=0))

    # possible_observations :: steps n batch_size num_states inner_shape

    hidden_one_hot = tf.reshape(
        hidden_one_hot,
        tf.concat([[self._num_steps, n], batch_shape, [num_states],
                   tf.ones_like(inner_shape)], axis=0))

    # hidden_one_hot :: steps n batch_size num_states "inner_shape"

    # Summing the masked product over the state axis keeps only the
    # observation for the realized hidden state.
    observations = tf.reduce_sum(hidden_one_hot * possible_observations,
                                 axis=-1 - tf.size(inner_shape))

    # observations :: steps n batch_size inner_shape

    observations = distribution_util.move_dimension(
        observations, 0, 1 + tf.size(batch_shape))

    # returned :: n batch_shape steps inner_shape

    return observations
def format_example(imgs, labels):
    """Formats each training and test example to work with our model.

    Each image is flattened to a 28 * 28 vector, cast to float32 and divided by
    255. ``labels`` become float32 one-hot vectors of depth 10.
    """
    flattened = tf.reshape(imgs, [-1, 28 * 28])
    scaled = tf.cast(flattened, tf.float32) / 255.0
    one_hot_labels = tf.one_hot(labels, depth=10, dtype=tf.float32)
    return scaled, one_hot_labels
def testShardedState(self):
    """SNAPER HMC on a model with one replicated and one sharded coordinate.

    The sharded coordinate has scale 2 on device 0 and 1 elsewhere. The test
    checks that the adapted EMA principal component and variance reflect this,
    and that step size and trajectory length agree across the two state parts.
    """
    if not JAX_MODE:
        self.skipTest('b/181800108')
    num_burnin_steps = 1000
    num_adaptation_steps = int(num_burnin_steps * 0.8)
    num_results = 500
    num_chains = 64
    step_size = 1e-2
    num_mala_steps = 100

    def trace_fn(_, pkr):
        # Pull the adaptation statistics out of the nested kernel results.
        return {
            'step_size':
                unnest.get_innermost(pkr, 'step_size'),
            'mean_trajectory_length':
                unnest.get_innermost(pkr, 'max_trajectory_length') / 2.,
            'principal_component':
                unnest.get_innermost(pkr, 'ema_principal_component'),
            'variance':
                unnest.get_innermost(pkr, 'ema_variance'),
            'num_leapfrog_steps':
                unnest.get_innermost(pkr, 'num_leapfrog_steps'),
        }

    # Two state parts (replicated and sharded), each starting at zero.
    init_x = ([
        self.shard_values(
            tf.zeros((distribute_test_lib.NUM_DEVICES, num_chains)))
    ] * 2)
    # Scale is 2 on device 0 and 1 on all other devices.
    local_scale = self.shard_values(
        1. + tf.one_hot(0, distribute_test_lib.NUM_DEVICES))

    @tf.function(autograph=False)
    def run(init_x, local_scale):

        @tfp.experimental.distribute.JointDistributionCoroutine
        def model():
            yield tfd.Normal(0., 1.)
            yield tfp.experimental.distribute.Sharded(
                tfd.Normal(0., local_scale), shard_axis_name=self.axis_name)

        kernel = tfp.experimental.mcmc.SNAPERHamiltonianMonteCarlo(
            model.log_prob,
            step_size=step_size,
            num_adaptation_steps=num_adaptation_steps,
            num_mala_steps=num_mala_steps,
            experimental_shard_axis_names=list(
                model.experimental_shard_axis_names),
        )
        kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
            kernel,
            num_adaptation_steps=num_adaptation_steps,
        )

        # Burn-in is kept in the trace (num_burnin_steps=0) so the full
        # adaptation history is returned; assertions read the last step.
        return tfp.mcmc.sample_chain(
            num_results=num_burnin_steps + num_results,
            num_burnin_steps=0,
            current_state=init_x,
            kernel=kernel,
            trace_fn=trace_fn,
            seed=test_util.test_seed(sampler_type='stateless'))

    _, trace = self.evaluate(
        self.per_replica_to_tensor(
            self.strategy_run(
                run,
                args=(init_x, local_scale),
                axis_name=self.axis_name,
            )))

    # The principal component should point at device 0's sharded coordinate
    # (the largest scale), not at the replicated one.
    self.assertAllClose(0., trace['principal_component'][0][0, -1], atol=0.1)
    expected_local_principal_component = np.zeros(
        distribute_test_lib.NUM_DEVICES)
    expected_local_principal_component[0] = 1.
    self.assertAllClose(expected_local_principal_component,
                        trace['principal_component'][1][:, -1],
                        atol=0.1)

    # Variance ~1 for the replicated part; 4 (= 2**2) on device 0 and 1
    # elsewhere for the sharded part.
    self.assertAllClose(1., trace['variance'][0][0, -1], atol=0.1)
    expected_local_variance = np.ones(distribute_test_lib.NUM_DEVICES)
    expected_local_variance[0] = 4.
    self.assertAllClose(expected_local_variance, trace['variance'][1][:, -1],
                        rtol=0.2)

    # Shard consistency.
    self.assertAllClose(trace['step_size'][0], trace['step_size'][1])
    self.assertAllClose(trace['mean_trajectory_length'][0],
                        trace['mean_trajectory_length'][1])
def _sample_channels(self,
                     component_logits,
                     locs,
                     scales,
                     coeffs=None,
                     seed=None):
    """Sample a single pixel-iteration and apply channel conditioning.

    Args:
      component_logits: 4D `Tensor` of logits for the Categorical distribution
        over Quantized Logistic mixture components. Dimensions are `[batch_size,
        height, width, num_logistic_mix]`.
      locs: 5D `Tensor` of location parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.
      scales: 5D `Tensor` of scale parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.
      coeffs: 5D `Tensor` of coefficients for the linear dependence among color
        channels, or `None` if there is only one channel. Dimensions are
        `[batch_size, height, width, num_logistic_mix, num_coeffs]`, where
        `num_coeffs = num_channels * (num_channels - 1) // 2`.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details. It is split
        into independent seeds for the component draw and for each channel.

    Returns:
      samples: 4D `Tensor` of sampled image data with autoregression among
        channels. Dimensions are `[batch_size, height, width, num_channels]`.
    """
    num_channels = self.event_shape[-1]

    # Derive independent seeds. Reusing one seed for the component draw and for
    # every channel's logistic draw would correlate those samples.
    component_seed, *channel_seeds = samplers.split_seed(
        seed, n=num_channels + 1, salt='PixelCNNSampleChannels')

    # sample mixture components once for the entire pixel
    component_dist = categorical.Categorical(logits=component_logits)
    mask = tf.one_hot(indices=component_dist.sample(seed=component_seed),
                      depth=self._num_logistic_mix)
    # Trailing axis of size 1 broadcasts the mask over the last parameter axis.
    mask = tf.cast(mask[..., tf.newaxis], self.dtype)

    # apply mixture component mask and separate out RGB parameters
    masked_locs = tf.reduce_sum(locs * mask, axis=-2)
    loc_tensors = tf.split(masked_locs, num_channels, axis=-1)
    masked_scales = tf.reduce_sum(scales * mask, axis=-2)
    scale_tensors = tf.split(masked_scales, num_channels, axis=-1)

    if coeffs is not None:
        num_coeffs = num_channels * (num_channels - 1) // 2
        masked_coeffs = tf.reduce_sum(coeffs * mask, axis=-2)
        coef_tensors = tf.split(masked_coeffs, num_coeffs, axis=-1)

    channel_samples = []
    coef_count = 0
    for i in range(num_channels):
        loc = loc_tensors[i]
        # Shift channel i's location linearly by the channels already sampled.
        for c in channel_samples:
            loc += c * coef_tensors[coef_count]
            coef_count += 1

        logistic_samp = logistic.Logistic(
            loc=loc, scale=scale_tensors[i]).sample(seed=channel_seeds[i])
        logistic_samp = tf.clip_by_value(logistic_samp, -1., 1.)
        channel_samples.append(logistic_samp)

    return tf.concat(channel_samples, axis=-1)
def _sample_n(self, n, seed):
    """Draw ``n`` samples from the mixture.

    Draws ``n`` samples from every component, draws ``n`` component indices
    from ``self.mixture_distribution``, and keeps each sample's selected
    component via a one-hot mask.

    If a sub-distribution rejects the stateless seed with a ``TypeError`` whose
    message matches the known seed errors, a warning is issued. Sampling is
    then retried with a stateful seed from ``SeedStream``.

    Args:
      n: Number of samples.
      seed: PRNG seed, split into component and mixture seeds.

    Returns:
      Samples shaped ``[n, B, E]`` (per the inline shape comments), optionally
      passed through ``self._reparameterize_sample``.
    """
    components_seed, mix_seed = samplers.split_seed(seed,
                                                    salt='MixtureSameFamily')
    # Stateful fallback stream; building it can fail for Tensor seeds, in
    # which case the error is kept and re-raised only if the fallback is
    # actually needed.
    try:
        seed_stream = SeedStream(seed, salt='MixtureSameFamily')
    except TypeError as e:  # Can happen for Tensor seeds.
        seed_stream = None
        seed_stream_err = e
    try:
        x = self.components_distribution.sample(  # [n, B, k, E]
            n, seed=components_seed)
        if seed_stream is not None:
            seed_stream()  # Advance even if unused.
    except TypeError as e:
        if ('Expected int for argument' not in str(e) and
                TENSOR_SEED_MSG_PREFIX not in str(e)):
            raise
        if seed_stream is None:
            raise seed_stream_err
        msg = (
            'Falling back to stateful sampling for `components_distribution` '
            '{} of type `{}`. Please update to use `tf.random.stateless_*` '
            'RNGs. This fallback may be removed after 20-Aug-2020. {}')
        warnings.warn(
            msg.format(self.components_distribution.name,
                       type(self.components_distribution), str(e)))
        x = self.components_distribution.sample(  # [n, B, k, E]
            n, seed=seed_stream())

    # Prefer the static event rank; fall back to the dynamic event shape.
    event_shape = None
    event_ndims = tensorshape_util.rank(self.event_shape)
    if event_ndims is None:
        event_shape = self.components_distribution.event_shape_tensor()
        event_ndims = prefer_static.rank_from_shape(event_shape)
    event_ndims_static = tf.get_static_value(event_ndims)

    # The component axis sits just before the event dims of x.
    num_components = None
    if event_ndims_static is not None:
        num_components = tf.compat.dimension_value(
            x.shape[-1 - event_ndims_static])
    # We could also check if num_components can be computed statically from
    # self.mixture_distribution's logits or probs.
    if num_components is None:
        num_components = tf.shape(x)[-1 - event_ndims]

    # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
    npdt = dtype_util.as_numpy_dtype(x.dtype)

    try:
        mix_sample = self.mixture_distribution.sample(
            n, seed=mix_seed)  # [n, B] or [n]
    except TypeError as e:
        if ('Expected int for argument' not in str(e) and
                TENSOR_SEED_MSG_PREFIX not in str(e)):
            raise
        if seed_stream is None:
            raise seed_stream_err
        msg = (
            'Falling back to stateful sampling for `mixture_distribution` '
            '{} of type `{}`. Please update to use `tf.random.stateless_*` '
            'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
        warnings.warn(
            msg.format(self.mixture_distribution.name,
                       type(self.mixture_distribution), str(e)))
        mix_sample = self.mixture_distribution.sample(
            n, seed=seed_stream())  # [n, B] or [n]
    # Mask in x's dtype so it can multiply x directly.
    mask = tf.one_hot(
        indices=mix_sample,  # [n, B] or [n]
        depth=num_components,
        on_value=npdt(1),
        off_value=npdt(0))  # [n, B, k] or [n, k]

    # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
    batch_ndims = prefer_static.rank(x) - event_ndims - 1
    mask_batch_ndims = prefer_static.rank(mask) - 1
    pad_ndims = batch_ndims - mask_batch_ndims
    mask_shape = prefer_static.shape(mask)
    mask = tf.reshape(
        mask,
        shape=prefer_static.concat([
            mask_shape[:-1],
            prefer_static.ones([pad_ndims], dtype=tf.int32),
            mask_shape[-1:],
            prefer_static.ones([event_ndims], dtype=tf.int32),
        ], axis=0))

    # For floating/complex dtypes, multiply_no_nan returns 0 where the mask is
    # 0 even if x is NaN/Inf there, so unselected components cannot poison the
    # sum.
    if x.dtype in [
            tf.bfloat16, tf.float16, tf.float32, tf.float64, tf.complex64,
            tf.complex128
    ]:
        masked = tf.math.multiply_no_nan(x, mask)
    else:
        masked = x * mask
    ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

    if self._reparameterize:
        if event_shape is None:
            event_shape = self.components_distribution.event_shape_tensor()
        ret = self._reparameterize_sample(ret, event_shape=event_shape)

    return ret
def convert(self, image, label):
    """Cast ``image`` to ``self.dtype`` divided by 255; one-hot ``label`` as int32.

    The label is encoded with depth ``self.output_size``.
    """
    divisor = tf.cast(255.0, dtype=self.dtype)
    scaled_image = tf.cast(image, self.dtype) / divisor
    one_hot_label = tf.one_hot(label, depth=self.output_size, dtype=tf.int32)
    return scaled_image, one_hot_label
def update(self,
           expert_dataset_iter,
           policy_dataset_iter,
           discount,
           replay_regularization=0.05,
           nu_reg=10.0):
  """Runs one joint update of the nu network and the actor.

  Nu is trained on `nu_loss` (DualDICE-style loss plus a gradient penalty);
  the actor is trained on `-loss` plus orthogonal regularization of its trunk.

  NOTE(review): the replay-buffer terms are currently commented out. Only
  expert data contributes to the loss. `replay_regularization` only scales the
  linear term by `(1 - replay_regularization)`, and `policy_dataset_iter` is
  not read. It does NOT learn the replay-mixed ratio
  (d_pi * (1 - r) + d_rb * r) / (d_expert * (1 - r) + d_rb * r) that the
  original design describes.

  Args:
    expert_dataset_iter: A TensorFlow iterator over expert data yielding
      `(states, actions, next_states)`.
    policy_dataset_iter: A TensorFlow iterator over training-policy data,
      intended for replay regularization (currently unused).
    discount: An MDP discount.
    replay_regularization: A fraction of samples to add from a replay buffer
      (currently only scales the linear loss term).
    nu_reg: A gradient-penalty regularization coefficient.
  """
  (expert_states, expert_actions,
   expert_next_states) = expert_dataset_iter.get_next()

  # Expert states double as the "initial" states for the linear term.
  expert_initial_states = expert_states

  # rb_states, rb_actions, rb_next_states, _, _ = policy_dataset_iter.get_next(
  # )[0]

  # Persistent tape: it is queried twice (nu grads and actor grads) below.
  with tf.GradientTape(
      watch_accessed_variables=False, persistent=True) as tape:
    tape.watch(self.actor.variables)
    tape.watch(self.nu_net.variables)

    _, policy_next_actions, _ = self.actor(expert_next_states)
    # _, rb_next_actions, rb_log_prob = self.actor(rb_next_states)

    _, policy_initial_actions, _ = self.actor(expert_initial_states)

    # Inputs for the linear part of DualDICE loss.
    expert_init_inputs = tf.concat(
        [expert_initial_states, policy_initial_actions], 1)

    if not self.discrete:
      expert_inputs = tf.concat([expert_states, expert_actions], 1)
    else:
      # Discrete actions are one-hot encoded before concatenation.
      mat = tf.one_hot(tf.cast(expert_actions, tf.int32),
                       depth=self.action_dim, axis=-1)
      expert_inputs = tf.concat([expert_states, mat], 1)
    expert_next_inputs = tf.concat(
        [expert_next_states, policy_next_actions], 1)

    # rb_inputs = tf.concat([rb_states, rb_actions], 1)
    # rb_next_inputs = tf.concat([rb_next_states, rb_next_actions], 1)

    expert_nu_0 = self.nu_net(expert_init_inputs)
    expert_nu = self.nu_net(expert_inputs)
    expert_nu_next = self.nu_net(expert_next_inputs)

    # rb_nu = self.nu_net(rb_inputs)
    # rb_nu_next = self.nu_net(rb_next_inputs)

    expert_diff = expert_nu - discount * expert_nu_next
    # rb_diff = rb_nu - discount * rb_nu_next

    linear_loss_expert = tf.reduce_mean(expert_nu_0 * (1 - discount))

    # linear_loss_rb = tf.reduce_mean(rb_diff)

    rb_expert_diff = expert_diff  #tf.concat([expert_diff, rb_diff], 0)
    rb_expert_weights = tf.ones(expert_diff.shape)  #tf.concat([
    #     tf.ones(expert_diff.shape) * (1 - replay_regularization),
    #     tf.ones(rb_diff.shape) * replay_regularization
    # ], 0)

    rb_expert_weights /= tf.reduce_sum(rb_expert_weights)
    # Softmax weights are treated as constants (stop_gradient); only the
    # weighted diffs carry gradient.
    non_linear_loss = tf.reduce_sum(
        tf.stop_gradient(
            weighted_softmax(rb_expert_diff, rb_expert_weights, axis=0)) *
        rb_expert_diff)

    linear_loss = (linear_loss_expert * (1 - replay_regularization) + 0)
    # linear_loss_rb * replay_regularization)

    loss = (non_linear_loss - linear_loss)

    # Gradient penalty on random interpolations between expert inputs and a
    # shuffled copy of the expert next-inputs.
    alpha = tf.random.uniform(shape=(expert_inputs.shape[0], 1))

    # nu_inter = alpha * expert_inputs + (1 - alpha) * expert_init_inputs #rb_inputs
    # nu_next_inter = alpha * expert_next_inputs + (1 - alpha) * #rb_next_inputs

    # nu_inter = tf.concat([nu_inter, nu_next_inter], 0)
    nu_inter = alpha * expert_inputs + (1 - alpha) * tf.stop_gradient(
        tf.random.shuffle(expert_next_inputs))

    with tf.GradientTape(watch_accessed_variables=False) as tape2:
      tape2.watch(nu_inter)
      nu_output = self.nu_net(nu_inter)
    # Computed inside the outer tape so the penalty is differentiable w.r.t.
    # the nu network's variables.
    nu_grad = tape2.gradient(nu_output, [nu_inter])[0] + EPS
    nu_grad_penalty = tf.reduce_mean(
        tf.square(tf.norm(nu_grad, axis=-1, keepdims=True) - 1))

    nu_loss = loss + nu_grad_penalty * nu_reg
    pi_loss = -loss + keras_utils.orthogonal_regularization(
        self.actor.trunk)

  nu_grads = tape.gradient(nu_loss, self.nu_net.variables)
  pi_grads = tape.gradient(pi_loss, self.actor.variables)

  self.nu_optimizer.apply_gradients(zip(nu_grads, self.nu_net.variables))
  self.actor_optimizer.apply_gradients(
      zip(pi_grads, self.actor.variables))

  # Release the persistent tape's resources.
  del tape

  self.avg_nu_expert(expert_nu)
  #self.avg_nu_rb(rb_nu)

  self.nu_reg_metric(nu_grad_penalty)
  self.avg_loss(loss)

  self.avg_actor_loss(pi_loss)
  #self.avg_actor_entropy(-rb_log_prob)

  if tf.equal(self.nu_optimizer.iterations % self.log_interval, 0):
    tf.summary.scalar(
        'train dual dice/loss',
        self.avg_loss.result(),
        step=self.nu_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_loss)

    tf.summary.scalar(
        'train dual dice/nu expert',
        self.avg_nu_expert.result(),
        step=self.nu_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_nu_expert)

    # NOTE(review): `avg_nu_rb` is never updated in this method (its update
    # is commented out above), so this summary reports stale state.
    tf.summary.scalar(
        'train dual dice/nu rb',
        self.avg_nu_rb.result(),
        step=self.nu_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_nu_rb)

    tf.summary.scalar(
        'train dual dice/nu reg',
        self.nu_reg_metric.result(),
        step=self.nu_optimizer.iterations)
    keras_utils.my_reset_states(self.nu_reg_metric)

  if tf.equal(self.actor_optimizer.iterations % self.log_interval, 0):
    tf.summary.scalar(
        'train sac/actor_loss',
        self.avg_actor_loss.result(),
        step=self.actor_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_actor_loss)

    # NOTE(review): `avg_actor_entropy` is never updated in this method (its
    # update is commented out above), so this summary reports stale state.
    tf.summary.scalar(
        'train sac/actor entropy',
        self.avg_actor_entropy.result(),
        step=self.actor_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_actor_entropy)
def boolean_mask(boxlist,
                 indicator,
                 fields=None,
                 scope=None,
                 use_static_shapes=False,
                 indicator_sum=None):
  """Select boxes from BoxList according to indicator and return new BoxList.

  `boolean_mask` returns the subset of boxes that are marked as "True" by the
  indicator tensor. By default, `boolean_mask` returns boxes corresponding to
  the input index list, as well as all additional fields stored in the boxlist
  (indexing into the first dimension). However one can optionally only draw
  from a subset of fields.

  Args:
    boxlist: BoxList holding N boxes
    indicator: a rank-1 boolean tensor
    fields: (optional) list of fields to also gather from. If None (default),
      all fields are gathered from. Pass an empty fields list to only gather
      the box coordinates. Honored on both the dynamic and the static-shape
      code paths.
    scope: name scope.
    use_static_shapes: Whether to use an implementation with static shape
      guarantees.
    indicator_sum: An integer containing the sum of `indicator` vector. Only
      required if `use_static_shapes` is True.

  Returns:
    subboxlist: a BoxList corresponding to the subset of the input BoxList
      specified by indicator

  Raises:
    ValueError: if `indicator` is not a rank-1 boolean tensor, if
      `use_static_shapes` is True and `indicator_sum` is not a non-zero int,
      or if a requested field is missing from `boxlist`.
  """
  with tf.name_scope(scope, 'BooleanMask'):
    if indicator.shape.ndims != 1:
      raise ValueError('indicator should have rank 1')
    if indicator.dtype != tf.bool:
      raise ValueError('indicator should be a boolean tensor')
    if use_static_shapes:
      if not (indicator_sum and isinstance(indicator_sum, int)):
        raise ValueError('`indicator_sum` must be a of type int')
      # cumsum * mask gives each selected position its 1-based ordinal and 0
      # to unselected ones; after "- 1" unselected rows become -1, which
      # tf.one_hot maps to an all-zero row.
      selected_positions = tf.cast(indicator, dtype=tf.float32)
      indexed_positions = tf.cast(
          tf.multiply(tf.cumsum(selected_positions), selected_positions),
          dtype=tf.int32)
      one_hot_selector = tf.one_hot(
          indexed_positions - 1, indicator_sum, dtype=tf.float32)
      # Contracting the position range against the selector yields, for each
      # of the `indicator_sum` output slots, the index of the selected box.
      sampled_indices = tf.cast(
          tf.tensordot(
              tf.cast(tf.range(tf.shape(indicator)[0]), dtype=tf.float32),
              one_hot_selector,
              axes=[0, 0]),
          dtype=tf.int32)
      # Forward `fields` so the static path honors the same field selection
      # as the dynamic path below.
      return gather(
          boxlist, sampled_indices, fields=fields, use_static_shapes=True)
    else:
      subboxlist = box_list.BoxList(
          tf.boolean_mask(boxlist.get(), indicator))
      if fields is None:
        fields = boxlist.get_extra_fields()
      for field in fields:
        if not boxlist.has_field(field):
          raise ValueError('boxlist must contain all specified fields')
        subfieldlist = tf.boolean_mask(boxlist.get_field(field), indicator)
        subboxlist.add_field(field, subfieldlist)
      return subboxlist
def _binary_crossover(population,
                      population_size,
                      mutants,
                      crossover_prob,
                      seed):
  """Performs recombination by binary crossover for the current population.

  Let v_i denote the i'th component of the member v and m_i the corresponding
  component of the mutant vector corresponding to v. Then the crossed over
  vector w_i is determined by setting w_i =
  (m_i with probability=crossover_prob else v_i). In addition, DE requires that
  at least one of the components is crossed over (otherwise we end
  up with no change). This is done by choosing an index, say k, randomly where
  a force crossover is performed (i.e. w_k = m_k). This is the scheme
  implemented in this function.

  When the population is split into several parts, the part that receives the
  forced crossover is chosen with probability proportional to its size, so the
  forced index is uniform over all components of a member.

  Args:
    population: A Python list of `Tensor`s where each `Tensor` in the list
      must be of rank at least 1 and all the elements must have a common
      first dimension. The base population to cross over.
    population_size: A scalar integer `Tensor`. The number of elements in the
      population (i.e. size of the first dimension of any member of
      `population`).
    mutants: A Python list of `Tensor`s with the same structure as `population`.
      The mutated population.
    crossover_prob: A positive real scalar `Tensor` bounded above by 1.0. The
      probability of a crossover being performed for each axis.
    seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
      applied.

  Returns:
    A list of `Tensor`s of the same structure, dtype and shape as `population`.
    The recombined population.
  """
  sizes = [tf.cast(tf.size(x), dtype=tf.float64) for x in population]
  seed_stream = tfp_util.SeedStream(seed, salt='binary_crossover')
  # `sizes` are relative weights, not logits. Passing them positionally to
  # Categorical would pick a part with probability ~ exp(size) (almost always
  # the largest part); normalize them into probabilities instead.
  size_weights = tf.stack(sizes)
  force_crossover_group = distributions.Categorical(
      probs=size_weights / tf.reduce_sum(size_weights)).sample(
          [population_size, 1], seed=seed_stream())
  recombinants = []
  for i, population_part in enumerate(population):
    pop_part_flat = tf.reshape(population_part, [population_size, -1])
    mutant_part_flat = tf.reshape(mutants[i], [population_size, -1])
    part_size = tf.size(population_part) // population_size
    # One randomly chosen coordinate per member: [population_size, part_size].
    force_crossovers = tf.one_hot(
        tf.random.uniform([population_size],
                          minval=0,
                          maxval=part_size,
                          dtype=tf.int32,
                          seed=seed_stream()),
        part_size,
        on_value=True,
        off_value=False,
        dtype=tf.bool)
    # Only members whose forced-crossover group is this part keep it.
    group_mask = tf.math.equal(force_crossover_group, i)
    force_crossovers &= group_mask
    do_binary_crossover = tf.random.uniform(
        [population_size, part_size],
        dtype=crossover_prob.dtype.base_dtype,
        seed=seed_stream()) < crossover_prob
    do_binary_crossover |= force_crossovers
    recombinant_flat = tf1.where(
        do_binary_crossover,
        mutant_part_flat,
        pop_part_flat)
    recombinant = tf.reshape(recombinant_flat, tf.shape(population_part))
    recombinants.append(recombinant)
  return recombinants
def _build_target_quantile_values_op(self):
  """Builds the op giving the regression target for the quantile values.

  Each reward is augmented with a self-imitation bonus:
  `alpha * (max(Q_target(s, a_taken), return) - max_a Q_target(s, a))`,
  clipped to `[-clip, clip]` when `self._clip > 0`. The target is then
  `reward + gamma_with_terminal * Z_target(s', argmax_a)`, tiled across the
  tau' samples.

  Returns:
    An op calculating the target quantile return,
    (num_tau_prime_samples x batch_size) x 1.
  """
  n_tau = self.num_tau_prime_samples
  batch_size = tf.shape(self._replay.rewards)[0]

  # Self-imitation bonus.
  action_mask = tf.one_hot(self._replay.actions, self.num_actions, 1., 0.,
                           name='action_one_hot')
  greedy_target_q = tf.reduce_max(self._replay_target_q_values, axis=1,
                                  name='replay_chosen_target_q')
  taken_target_q = tf.reduce_sum(action_mask * self._replay_target_q_values,
                                 axis=1, name='replay_chosen_target_q_al')
  best_value = tf.math.maximum(taken_target_q, self._replay.returns)
  advantage = best_value - greedy_target_q
  if self._clip > 0.:
    advantage = tf.clip_by_value(advantage, -self._clip, self._clip)
  sil_bonus = self._alpha * advantage

  # (num_tau_prime_samples x batch_size) x 1.
  rewards = tf.tile((self._replay.rewards + sil_bonus)[:, None], [n_tau, 1])

  # Discount is zeroed for terminal transitions;
  # (num_tau_prime_samples x batch_size) x 1.
  not_terminal = 1. - tf.cast(self._replay.terminals, tf.float32)
  gamma_with_terminal = tf.tile(
      (self.cumulative_gamma * not_terminal)[:, None], [n_tau, 1])

  # Greedy next action for every (tau', batch) row;
  # (num_tau_prime_samples x batch_size) x 1.
  next_argmax = tf.tile(self._replay_next_qt_argmax[:, None], [n_tau, 1])
  row_ids = tf.cast(tf.range(n_tau * batch_size)[:, None], tf.int64)
  # (row, action) pairs: (num_tau_prime_samples x batch_size) x 2.
  gather_idx = tf.concat([row_ids, next_argmax], axis=1)
  # (num_tau_prime_samples x batch_size) x 1.
  target_quantile_values = tf.gather_nd(
      self._replay_net_target_quantile_values, gather_idx)[:, None]

  return rewards + gamma_with_terminal * target_quantile_values
def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    trace_fn=lambda current_state, kernel_results: kernel_results,
    return_final_kernel_results=False,
    parallel_iterations=10,
    name=None,
):
  """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain at `current_state` whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or not
  there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set
  of states with decreased autocorrelation. See [Owen (2017)][1]. Such
  "thinning" is made possible by setting `num_steps_between_results > 0`. The
  chain then takes `num_steps_between_results` extra steps between the steps
  that make it into the results. The extra steps are never materialized, and
  thus do not increase memory requirements.

  Warning: when setting a `seed` in the `kernel`, ensure that `sample_chain`'s
  `parallel_iterations=1`, otherwise results will not be reproducible.

  In addition to returning the chain state, this function supports tracing of
  auxiliary variables used by the kernel. The traced values are selected by
  specifying `trace_fn`. By default, all kernel results are traced but in the
  future the default will be changed to no results being traced, so plan
  accordingly. See below for some examples of this feature.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
      representing internal calculations made within the previous call to this
      function (or as returned by `bootstrap_results`). If `None`, it is
      computed with `kernel.bootstrap_results(current_state)`.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain. Required in practice despite the `None`
      default: `kernel.is_calibrated` is read unconditionally.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results. The number of returned chain states is
      still equal to `num_results`. Default value: 0 (i.e., no thinning).
    trace_fn: A callable that takes in the current chain state and the previous
      kernel results and returns a `Tensor` or a nested collection of `Tensor`s
      that is then traced along with the chain state.
    return_final_kernel_results: If `True`, then the final kernel results are
      returned alongside the chain state and the trace specified by the
      `trace_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer. See `tf.while_loop` for more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "mcmc_sample_chain").

  Returns:
    checkpointable_states_and_trace: if `return_final_kernel_results` is
      `True`. The return value is an instance of
      `CheckpointableStatesAndTrace`.
    all_states: if `return_final_kernel_results` is `False` and `trace_fn` is
      `None`. The return value is a `Tensor` or Python list of `Tensor`s
      representing the state(s) of the Markov chain(s) at each result step. Has
      same shape as input `current_state` but with a prepended
      `num_results`-size dimension.
    states_and_trace: if `return_final_kernel_results` is `False` and
      `trace_fn` is not `None`. The return value is an instance of
      `StatesAndTrace`.

  #### Examples

  ##### Sample from a diagonal-variance Gaussian.

  I.e.,

  ```none
  for i=1..n:
    x[i] ~ MultivariateNormal(loc=0, scale=diag(true_stddev))  # likelihood
  ```

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  dims = 10
  true_stddev = np.sqrt(np.linspace(1., 3., dims))
  likelihood = tfd.MultivariateNormalDiag(loc=0., scale_diag=true_stddev)

  states = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros(dims),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=None)

  sample_mean = tf.reduce_mean(states, axis=0)
  # ==> approx all zeros

  sample_stddev = tf.sqrt(tf.reduce_mean(
      tf.math.squared_difference(states, sample_mean),
      axis=0))
  # ==> approx equal true_stddev
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  # prior
  w ~ MultivariateNormal(loc=0, scale=eye(d))
  for i=1..n:
    # likelihood
    x[i] ~ Normal(loc=w^T F[i], scale=1)
  ```

  where `F` denotes factors.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Specify model.
  def make_prior(dims):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.matmul(weights, factors, adjoint_b=True))

  def joint_log_prob(num_weights, factors, x, w):
    return (make_prior(num_weights).log_prob(w) +
            make_likelihood(w, factors).log_prob(x))

  def unnormalized_log_posterior(w):
    # Posterior is proportional to: `p(W, X=x | factors)`.
    return joint_log_prob(num_weights, factors, x, w)

  # Setup data.
  num_weights = 10 # == d
  num_factors = 40 # == n
  num_chains = 100

  weights = make_prior(num_weights).sample(1)
  factors = tf.random.normal([num_factors, num_weights])
  x = make_likelihood(weights, factors).sample()

  # Sample from Hamiltonian Monte Carlo Markov Chain.

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros([num_chains, num_weights], name='init_weights'),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_log_posterior,
        step_size=0.1,
        num_leapfrog_steps=2))

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  # ==> approx equal to weights

  sample_var = tf.reduce_mean(
      tf.math.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  # ==> less than 1
  ```

  ##### Custom tracing functions.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  likelihood = tfd.Normal(loc=0., scale=1.)

  def sample_chain(trace_fn):
    return tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=0.,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=trace_fn)

  def trace_log_accept_ratio(states, previous_kernel_results):
    return previous_kernel_results.log_accept_ratio

  def trace_everything(states, previous_kernel_results):
    return previous_kernel_results

  _, log_accept_ratio = sample_chain(trace_fn=trace_log_accept_ratio)
  _, kernel_results = sample_chain(trace_fn=trace_everything)

  acceptance_prob = tf.math.exp(tf.minimum(log_accept_ratio, 0.))
  # Equivalent to, but more efficient than:
  acceptance_prob = tf.math.exp(tf.minimum(
      kernel_results.log_accept_ratio, 0.))
  ```

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
       _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
  if not kernel.is_calibrated:
    warnings.warn("supplied `TransitionKernel` is not calibrated. Markov "
                  "chain may not converge to intended target distribution.")
  with tf.name_scope(name or "mcmc_sample_chain"):
    num_results = tf.convert_to_tensor(
        num_results, dtype=tf.int32, name="num_results")
    num_burnin_steps = tf.convert_to_tensor(
        num_burnin_steps, dtype=tf.int32, name="num_burnin_steps")
    num_steps_between_results = tf.convert_to_tensor(
        num_steps_between_results,
        dtype=tf.int32,
        name="num_steps_between_results")
    current_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name="current_state"),
        current_state)
    if previous_kernel_results is None:
      previous_kernel_results = kernel.bootstrap_results(current_state)

    if trace_fn is None:
      # It simplifies the logic to use a dummy function here.
      trace_fn = lambda *args: ()
      no_trace = True
    else:
      no_trace = False
    # Index 4 of the defaults tuple is `trace_fn`'s default lambda; identity
    # comparison detects callers relying on the deprecated default.
    if trace_fn is sample_chain.__defaults__[4]:
      warnings.warn("Tracing all kernel results by default is deprecated. Set "
                    "the `trace_fn` argument to None (the future default "
                    "value) or an explicit callback that traces the values "
                    "you are interested in.")

    def _trace_scan_fn(state_and_results, num_steps):
      """Advances the chain `num_steps` kernel steps."""
      next_state, current_kernel_results = mcmc_util.smart_for_loop(
          loop_num_iter=num_steps,
          body_fn=kernel.one_step,
          initial_loop_vars=list(state_and_results),
          parallel_iterations=parallel_iterations)
      return next_state, current_kernel_results

    # `elems` gives the number of kernel steps per recorded result: the first
    # result takes 1 + num_burnin_steps steps, every later one takes
    # 1 + num_steps_between_results steps (thinning).
    (_, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
        loop_fn=_trace_scan_fn,
        initial_state=(current_state, previous_kernel_results),
        elems=tf.one_hot(
            indices=0,
            depth=num_results,
            on_value=1 + num_burnin_steps,
            off_value=1 + num_steps_between_results,
            dtype=tf.int32),
        # pylint: disable=g-long-lambda
        trace_fn=lambda state_and_results: (state_and_results[0],
                                            trace_fn(*state_and_results)),
        # pylint: enable=g-long-lambda
        parallel_iterations=parallel_iterations)

    if return_final_kernel_results:
      return CheckpointableStatesAndTrace(
          all_states=all_states,
          trace=trace,
          final_kernel_results=final_kernel_results)
    else:
      if no_trace:
        return all_states
      else:
        return StatesAndTrace(all_states=all_states, trace=trace)
def prepare_dataset(self, dataset: dataset_lib.OffpolicyDataset,
                    target_policy: tf_policy.TFPolicy):
  """Performs pre-computations on dataset to make solving easier.

  Flattens every episode into (initial_env_step, env_step, next_env_step)
  transitions, keeps only the steps that are valid and have `discount > 0`,
  and builds the tabular quantities used by the solver. It stores on `self`:
  the Bellman residual vectors `_a_vec`, the TD matrix `_td_mat`, the reward
  vector `_bias`, the per-sample weighted rewards, state-action counts, nu
  indices, initial nu indices and initial target probabilities. `_nu` and
  `_nu2` are reinitialized from `bias`.

  Args:
    dataset: Off-policy dataset providing `get_all_episodes`.
    target_policy: Policy whose action distribution is evaluated on the
      flattened steps.
  """
  episodes, valid_steps = dataset.get_all_episodes(
      limit=self._limit_episodes)
  total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
  num_episodes = tf.shape(valid_steps)[0]
  num_samples = num_episodes * total_num_steps_per_episode
  valid_and_not_last = tf.logical_and(valid_steps, episodes.discount > 0)
  # tf.where on a 1-D mask yields [num_true, 1]; squeeze only that trailing
  # axis so a single valid transition still produces a rank-1 index vector
  # (a bare squeeze would collapse it to a scalar and drop the batch dim of
  # every gathered step).
  valid_indices = tf.squeeze(
      tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])), axis=-1)

  # Flatten all tensors so that each data sample is a tuple of
  # (initial_env_step, env_step, next_env_step).
  # NOTE(review): the bare tf.squeeze calls below drop trailing size-1 dims of
  # scalar fields, but would also drop the sample dim if num_samples == 1.
  initial_env_step = tf.nest.map_structure(
      lambda t: tf.squeeze(
          tf.reshape(
              tf.repeat(
                  t[:, 0:1, ...],
                  axis=1,
                  repeats=total_num_steps_per_episode), [num_samples, -1])),
      episodes)
  initial_env_step = tf.nest.map_structure(
      lambda t: tf.gather(t, valid_indices), initial_env_step)
  tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
      initial_env_step)

  env_step = tf.nest.map_structure(
      lambda t: tf.squeeze(
          tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                     [num_samples, -1])), episodes)
  env_step = tf.nest.map_structure(
      lambda t: tf.gather(t, valid_indices), env_step)
  tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

  next_env_step = tf.nest.map_structure(
      lambda t: tf.squeeze(
          tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                     [num_samples, -1])), episodes)
  next_env_step = tf.nest.map_structure(
      lambda t: tf.gather(t, valid_indices), next_env_step)
  tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
      next_env_step)

  # Get target probabilities for initial and next steps.
  initial_target_probs = target_policy.distribution(
      tfagents_initial_env_step).action.probs_parameter()
  next_target_probs = target_policy.distribution(
      tfagents_next_env_step).action.probs_parameter()

  # Map states and actions to indices into tabular representation.
  # Every state is paired with every action: [n_data, num_actions].
  initial_states = tf.tile(
      tf.reshape(initial_env_step.observation, [-1, 1]),
      [1, self._num_actions])
  initial_actions = tf.tile(
      tf.reshape(tf.range(self._num_actions), [1, -1]),
      [initial_env_step.observation.shape[0], 1])
  initial_nu_indices = self._get_index(initial_states, initial_actions)

  next_states = tf.tile(
      tf.reshape(next_env_step.observation, [-1, 1]),
      [1, self._num_actions])
  next_actions = tf.tile(
      tf.reshape(tf.range(self._num_actions), [1, -1]),
      [next_env_step.observation.shape[0], 1])
  next_nu_indices = self._get_index(next_states, next_actions)
  # Absorbing next states get index -1, which tf.one_hot below maps to an
  # all-zero row, removing their bootstrap contribution.
  next_nu_indices = tf.where(
      tf.expand_dims(next_env_step.is_absorbing(), -1),
      -1 * tf.ones_like(next_nu_indices), next_nu_indices)

  nu_indices = self._get_index(env_step.observation, env_step.action)

  target_log_probabilities = target_policy.distribution(
      tfagents_env_step).action.log_prob(env_step.action)
  if not self._solve_for_state_action_ratio:
    policy_ratio = tf.exp(target_log_probabilities -
                          env_step.get_log_probability())
  else:
    policy_ratio = tf.ones([
        target_log_probabilities.shape[0],
    ])
  policy_ratios = tf.tile(
      tf.reshape(policy_ratio, [-1, 1]), [1, self._num_actions])

  # Bellman residual matrix of size [n_data, n_dim].
  a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
      self._gamma *
      tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
      tf.one_hot(next_nu_indices, self._dimension),
      axis=1)

  state_action_count = self._get_state_action_counts(env_step)
  # Bellman residual matrix of size [n_dim, n_dim].
  td_mat = tf.einsum('ai, a, aj -> ij',
                     tf.one_hot(nu_indices, self._dimension),
                     1.0 / tf.cast(state_action_count, tf.float32), a_vec)

  # Reward vector of size [n_data].
  weighted_rewards = policy_ratio * self._reward_fn(env_step)

  # Reward vector of size [n_dim].
  bias = tf.reduce_sum(
      tf.one_hot(nu_indices, self._dimension) *
      tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
      tf.cast(state_action_count, tf.float32)[:, None],
      axis=0)

  # Initialize.
  self._nu = np.ones_like(self._nu) * bias[:, None]
  self._nu2 = np.ones_like(self._nu2) * bias[:, None]

  self._a_vec = a_vec
  self._td_mat = td_mat
  self._bias = bias
  self._weighted_rewards = weighted_rewards
  self._state_action_count = state_action_count
  self._nu_indices = nu_indices
  self._initial_nu_indices = initial_nu_indices
  self._initial_target_probs = initial_target_probs
def _sample_n(self, n, seed=None):
  """Draws `n` samples from the mixture.

  Two strategies are used:
  * `self._use_static_graph`: sample every component `n` times, stack them,
    and select per-sample with a one-hot mask over the component axis.
  * Otherwise: sample the categorical, partition sample positions by chosen
    component, draw only as many samples as each component needs, and
    stitch the results back into place with `tf.dynamic_stitch`.

  Args:
    n: Number of samples to draw.
    seed: Optional seed for `self.cat` and for the per-component
      `SeedStream`.

  Returns:
    Samples with shape `[n] + batch_shape + event_shape`.
  """
  if self._use_static_graph:
    with tf.control_dependencies(self._assertions):
      # This sampling approach is almost the same as the approach used by
      # `MixtureSameFamily`. The differences are due to having a list of
      # `Distribution` objects rather than a single object, and maintaining
      # random seed management that is consistent with the non-static code
      # path.
      samples = []
      cat_samples = self.cat.sample(n, seed=seed)
      stream = seed_stream.SeedStream(seed, salt="Mixture")

      for c in range(self.num_components):
        samples.append(self.components[c].sample(n, seed=stream()))
      x = tf.stack(samples, -self._static_event_shape.ndims - 1)  # [n, B, k, E]
      npdt = x.dtype.as_numpy_dtype
      mask = tf.one_hot(
          indices=cat_samples,  # [n, B]
          depth=self._num_components,  # == k
          on_value=np.ones([], dtype=npdt),
          off_value=np.zeros([], dtype=npdt))  # [n, B, k]
      mask = distribution_util.pad_mixture_dimensions(
          mask, self, self._cat,
          self._static_event_shape.ndims)  # [n, B, k, [1]*e]
      return tf.reduce_sum(
          input_tensor=x * mask,
          axis=-1 - self._static_event_shape.ndims)  # [n, B, E]

  with tf.control_dependencies(self._assertions):
    n = tf.convert_to_tensor(value=n, name="n")
    # Use a Python int for `n` when it is statically known.
    static_n = tf.get_static_value(n)
    n = int(static_n) if static_n is not None else n
    cat_samples = self.cat.sample(n, seed=seed)

    # Prefer static shapes/sizes; fall back to dynamic tensors.
    static_samples_shape = cat_samples.shape
    if static_samples_shape.is_fully_defined():
      samples_shape = static_samples_shape.as_list()
      samples_size = static_samples_shape.num_elements()
    else:
      samples_shape = tf.shape(input=cat_samples)
      samples_size = tf.size(input=cat_samples)
    static_batch_shape = self.batch_shape
    if static_batch_shape.is_fully_defined():
      batch_shape = static_batch_shape.as_list()
      batch_size = static_batch_shape.num_elements()
    else:
      batch_shape = self.batch_shape_tensor()
      batch_size = tf.reduce_prod(input_tensor=batch_shape)
    static_event_shape = self.event_shape
    if static_event_shape.is_fully_defined():
      event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
    else:
      event_shape = self.event_shape_tensor()

    # Get indices into the raw cat sampling tensor. We will
    # need these to stitch sample values back out after sampling
    # within the component partitions.
    samples_raw_indices = tf.reshape(
        tf.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use
    # dynamic_stitch later to reconstruct the samples from the
    # known partitions.
    partitioned_samples_indices = tf.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)

    # Copy the batch indices n times, as we will need to know
    # these to pull out the appropriate rows within the
    # component partitions.
    batch_raw_indices = tf.reshape(
        tf.tile(tf.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
    # Suppose partitions are:
    #     [1 1 0 0 1 1]
    # After partitioning, batch indices are cut as:
    #     [batch_indices[x] for x in 2, 3]
    #     [batch_indices[x] for x in 0, 1, 4, 5]
    # i.e.
    #     [0 1] and [0 1 0 1]
    # Now we sample n=2 from part 0 and n=4 from part 1.
    # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
    # and for part 1 we want samples from batch entries 0, 1, 0, 1
    # (samples 0, 1, 2, 3).
    partitioned_batch_indices = tf.dynamic_partition(
        data=batch_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    samples_class = [None for _ in range(self.num_components)]

    stream = seed_stream.SeedStream(seed, salt="Mixture")

    for c in range(self.num_components):
      n_class = tf.size(input=partitioned_samples_indices[c])
      samples_class_c = self.components[c].sample(n_class, seed=stream())

      # Pull out the correct batch entries from each index.
      # To do this, we may have to flatten the batch shape.

      # For sample s, batch element b of component c, we get the
      # partitioned batch indices from
      # partitioned_batch_indices[c]; and shift each element by
      # the sample index. The final lookup can be thought of as
      # a matrix gather along locations (s, b) in
      # samples_class_c where the n_class rows correspond to
      # samples within this component and the batch_size columns
      # correspond to batch elements within the component.
      #
      # Thus the lookup index is
      #   lookup[c, i] = batch_size * s[i] + b[c, i]
      # for i = 0 ... n_class[c] - 1.
      lookup_partitioned_batch_indices = (
          batch_size * tf.range(n_class) + partitioned_batch_indices[c])
      samples_class_c = tf.reshape(
          samples_class_c,
          tf.concat([[n_class * batch_size], event_shape], 0))
      samples_class_c = tf.gather(
          samples_class_c,
          lookup_partitioned_batch_indices,
          name="samples_class_c_gather")
      samples_class[c] = samples_class_c

    # Stitch back together the samples across the components.
    lhs_flat_ret = tf.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    # Reshape back to proper sample, batch, and event shape.
    ret = tf.reshape(
        lhs_flat_ret,
        tf.concat([samples_shape, self.event_shape_tensor()], 0))
    ret.set_shape(
        tf.TensorShape(static_samples_shape).concatenate(self.event_shape))
    return ret
def _write_update_to_result():
  """Returns `tensor` with the slice at `ind` along `axis` taken from `new_tensor`.

  Reads `ind`, `size_along_axis`, `tensor`, `axis` and `new_tensor` from the
  enclosing scope. `ind` is presumably a scalar index — confirm at the call
  site.
  """
  selector = tf.one_hot(ind, depth=size_along_axis) > 0
  # Shape that is 1 everywhere except along `axis`, so the mask broadcasts.
  broadcast_shape = [1 for _ in range(len(tensor.shape))]
  broadcast_shape[axis] = size_along_axis
  return tf.where(tf.reshape(selector, broadcast_shape), new_tensor, tensor)
labels_one_hot=tf.constant(y_train_onehot), samples=tf.constant(samples), weights=tf.constant(weights), _lambda=tf.constant(_lambda)) loss1, share_loss1, tau = loss1.numpy(), share_loss1.numpy( ), tau.numpy() save_tau.append(tau) save_loss1.append(loss1) save_share_loss1.append(share_loss1) black_box_probs = black_box(X_train, trainable=tf.constant(False)) black_box_labels = np.argmax(black_box_probs.numpy(), axis=1) if not (config_params["weights"]): black_box_probs = tf.one_hot(black_box_labels, len(np.unique(y_train))) #Learning sTGMA #tf.print("----Begin----") for j in range(hyper_params["stgma_steps"]): #print("Iteration sTGMA: ", j) toc = time() #print(X_train.shape,black_box_labels.astype(np.int32).shape, responsibilities.shape, samples.shape, black_box_probs.shape, tf.random.shuffle(tf.range(model.data_dim), seed=(2^step)*(2*j+1) )) loss, share_loss2 = train_step_sTGMA( data=tf.constant(X_train), labels=tf.constant(black_box_labels.astype(np.int32)), responsibilities=tf.constant(responsibilities), eta=tf.constant(eta, dtype=tf.float32), samples=tf.constant(samples), weights=tf.constant(black_box_probs), t_range=tf.random.shuffle(tf.range(model.data_dim),
def step_fn(inputs):
  """Per-Replica StepFn.

  Runs one training step on this replica's shard. It does a forward pass under
  a gradient tape and computes NLL + L2 + linearly annealed KL loss. It then
  applies gradients, optionally scaling fast-weight gradients by
  `FLAGS.fast_weight_lr_multiplier`, and updates training metrics.

  Closes over `FLAGS`, `model`, `optimizer`, `strategy`, `train_dataset_size`,
  `metrics`, `training_diversity` and `ed` from the enclosing scope.

  Args:
    inputs: `(images, labels)` pair for this replica.
  """
  images, labels = inputs
  if FLAGS.version2 and FLAGS.ensemble_size > 1:
    images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
    # With member sampling / expected probs the ensemble axis is reduced away
    # before the loss, so labels keep their original batch size.
    if not (FLAGS.member_sampling or FLAGS.expected_probs):
      labels = tf.tile(labels, [FLAGS.ensemble_size])
  if FLAGS.num_train_samples > 1:
    images = tf.tile(images, [FLAGS.num_train_samples, 1, 1, 1])

  with tf.GradientTape() as tape:
    logits = model(images, training=True)
    probs = tf.nn.softmax(logits)
    # Diversity evaluation.
    if FLAGS.version2 and FLAGS.ensemble_size > 1:
      # `images` were tiled by ensemble member first and by MC sample second,
      # so the leading axis of `probs` is laid out as
      # [num_train_samples, ensemble_size, batch]. Move the ensemble axis to
      # the front before flattening so each row of `per_probs` holds the
      # predictions of exactly one member. With a single sample this is the
      # plain [ensemble_size, batch, ...] reshape.
      num_samples = (
          FLAGS.num_train_samples if FLAGS.num_train_samples > 1 else 1)
      per_probs = tf.reshape(
          probs,
          tf.concat([[num_samples, FLAGS.ensemble_size, -1],
                     probs.shape[1:]], 0))
      per_probs_rank = probs.shape.rank + 2
      per_probs = tf.transpose(
          per_probs, [1, 0] + list(range(2, per_probs_rank)))
      per_probs = tf.reshape(
          per_probs,
          tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
      diversity_results = ed.metrics.average_pairwise_diversity(
          per_probs, FLAGS.ensemble_size)

    if FLAGS.num_train_samples > 1:
      # Average the predictive distribution over MC samples.
      probs = tf.reshape(
          probs,
          tf.concat([[FLAGS.num_train_samples, -1], probs.shape[1:]], 0))
      probs = tf.reduce_mean(probs, 0)

    if FLAGS.member_sampling and FLAGS.version2 and FLAGS.ensemble_size > 1:
      # Keep only one randomly chosen member's predictions, selected through
      # a one-hot matmul.
      idx = tf.random.uniform([], maxval=FLAGS.ensemble_size, dtype=tf.int64)
      idx_one_hot = tf.expand_dims(
          tf.one_hot(idx, FLAGS.ensemble_size, dtype=probs.dtype), 0)
      probs_shape = probs.shape
      probs = tf.reshape(probs, [FLAGS.ensemble_size, -1])
      probs = tf.matmul(idx_one_hot, probs)
      probs = tf.reshape(probs, tf.concat([[-1], probs_shape[1:]], 0))
    elif FLAGS.expected_probs and FLAGS.version2 and FLAGS.ensemble_size > 1:
      # Average the predictive distribution over ensemble members.
      probs = tf.reshape(
          probs,
          tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
      probs = tf.reduce_mean(probs, 0)

    negative_log_likelihood = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

    filtered_variables = []
    for var in model.trainable_variables:
      # Apply l2 on the slow weights and bias terms. This excludes BN
      # parameters and fast weight approximate posterior/prior parameters,
      # but pay caution to their naming scheme.
      if 'kernel' in var.name or 'bias' in var.name:
        filtered_variables.append(tf.reshape(var, (-1, )))

    l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
        tf.concat(filtered_variables, axis=0))
    kl = sum(model.losses) / train_dataset_size
    # Linear KL annealing: ramps from 1/kl_annealing_steps up to 1.
    kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
    kl_scale /= FLAGS.kl_annealing_steps
    kl_scale = tf.minimum(1., kl_scale)
    kl_loss = kl_scale * kl

    # Scale the loss given the TPUStrategy will reduce sum all gradients.
    loss = negative_log_likelihood + l2_loss + kl_loss
    scaled_loss = loss / strategy.num_replicas_in_sync

  grads = tape.gradient(scaled_loss, model.trainable_variables)

  # Separate learning rate implementation.
  grad_list = []
  if FLAGS.fast_weight_lr_multiplier != 1.0:
    grads_and_vars = list(zip(grads, model.trainable_variables))
    for grad, var in grads_and_vars:
      # Apply different learning rate on the fast weight approximate
      # posterior/prior parameters. This excludes BN and slow weights,
      # but pay caution to the naming scheme.
      if ('batch_norm' not in var.name and 'kernel' not in var.name):
        grad_list.append((grad * FLAGS.fast_weight_lr_multiplier, var))
      else:
        grad_list.append((grad, var))
    optimizer.apply_gradients(grad_list)
  else:
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  metrics['train/ece'].update_state(labels, probs)
  metrics['train/loss'].update_state(loss)
  metrics['train/negative_log_likelihood'].update_state(
      negative_log_likelihood)
  metrics['train/accuracy'].update_state(labels, probs)
  if FLAGS.version2 and FLAGS.ensemble_size > 1:
    for k, v in diversity_results.items():
      training_diversity['train/' + k].update_state(v)
def _mode(self):
  """Returns the mode as a one-hot vector of `self.dtype`.

  The hot position is the argmax of `self.logits` along axis
  `self._batch_rank`; the result is given the static shape of `self.logits`.
  """
  index_of_max = tf.argmax(input=self.logits, axis=self._batch_rank)
  mode = tf.one_hot(index_of_max, self.event_size, dtype=self.dtype)
  tensorshape_util.set_shape(mode, self.logits.shape)
  return mode
def softquantiles(x,
                  quantiles,
                  quantile_width=None,
                  axis=-1,
                  may_squeeze=True,
                  **kwargs):
  """Computes soft quantiles via optimal transport.

  This operator takes advantage of the fact that an exhaustive softsort is not
  required to recover a single quantile. Instead, one can transport all
  input values in x onto only 3 weighted values. Target weights are adjusted so
  that those values in x that are transported to the middle value in the target
  vector y correspond to those concentrating around the quantile of interest.

  This idea generalizes to more quantiles, interleaving small weights on the
  quantile indices and bigger weights in between, corresponding to the gap from
  one desired quantile to the next one.

  Args:
    x: Tensor<float> of any shape. The weight arithmetic mixes float32
      constants with `x.dtype`, so `x` is expected to be float32.
    quantiles: list<float> the quantiles to be returned. It can also be a single
      float. Values outside the open interval (0, 1) are dropped.
    quantile_width: (float) mass given to the bucket supposed to attract points
      whose value concentrate around the desired quantile value. Bigger width
      means that we allow the soft quantile to be a mixture of more points
      further away from the quantile. If None, the width is set to the minimum
      of 1/n, where n is the number of values considered (the size along the
      'axis'), and the smallest gap between consecutive quantiles (including
      the implicit 0 and 1 endpoints).
    axis: (int) the axis along which to compute the quantile.
    may_squeeze: (bool) should we squeeze the output tensor in case of a single
      quantile.
    **kwargs: see SoftQuantilizer for possible extra parameters.

  Returns:
    A Tensor<float> similar to the input tensor, but the axis dimension is
    replaced by the number of quantiles specified in the quantiles list.
    Hence, if only a quantile is requested (quantiles is a float) only one value
    in that axis is returned. When several quantiles are requested, the tensor
    will have that many values in that axis.

  Raises:
    tf.errors.InvalidArgumentError when the quantiles and quantile width are not
    correct, namely quantiles are either not in sorted order or the
    quantile_width is too large.
  """
  if isinstance(quantiles, float):
    quantiles = [quantiles]
  quantiles = tf.constant(quantiles, tf.float32)

  # Preprocesses submitted quantiles to check that they satisfy elementary
  # constraints.
  valid_quantiles = tf.boolean_mask(
      quantiles, tf.logical_and(quantiles > 0.0, quantiles < 1.0))
  num_quantiles = tf.shape(valid_quantiles)[0]

  # Includes values on both ends of [0,1].
  extended_quantiles = tf.concat([[0.0], valid_quantiles, [1.0]], axis=0)

  # Builds filler_weights in between the target quantiles.
  filler_weights = extended_quantiles[1:] - extended_quantiles[:-1]
  if quantile_width is None:
    quantile_width = tf.reduce_min(
        tf.concat([
            filler_weights,
            [1.0 / tf.cast(tf.shape(x)[axis], dtype=x.dtype)]
        ], axis=0))

  # Takes into account quantile_width in the definition of weights: interior
  # fillers lose a full width, the two end fillers only half a width.
  shift = -tf.ones(tf.shape(filler_weights), dtype=x.dtype)
  shift = shift + 0.5 * (tf.one_hot(0, num_quantiles + 1) +
                         tf.one_hot(num_quantiles, num_quantiles + 1))
  filler_weights = filler_weights + quantile_width * shift

  # A negative filler means unsorted quantiles or a too-large width.
  assert_op = tf.Assert(tf.reduce_all(filler_weights >= 0.0), [filler_weights])
  with tf.control_dependencies([assert_op]):
    # Adds one more value to have tensors of the same shape to interleave them.
    quantile_weights = tf.ones(num_quantiles + 1) * quantile_width

  # Interleaves the filler_weights with the quantile weights.
  weights = tf.reshape(
      tf.stack([filler_weights, quantile_weights], axis=1), (-1, ))[:-1]

  # Sends only the positive weights to the softsort operator.
  positive_weights = tf.boolean_mask(weights, weights > 0.0)
  all_quantiles = softsort(x,
                           direction='ASCENDING',
                           axis=axis,
                           target_weights=positive_weights,
                           **kwargs)

  # Recovers the indices corresponding to the desired quantiles. Quantile
  # weights sit at the odd positions of `weights`; the running count of
  # positive weights maps each of them to its (1-based) slot in
  # `all_quantiles`. The dynamic length is used because `num_quantiles` comes
  # from a boolean mask, so the static length is unknown in graph mode.
  odds = tf.math.floormod(tf.range(tf.shape(weights)[0]), 2)
  positives = tf.cast(weights > 0.0, tf.int32)
  indices = tf.math.cumsum(positives) * odds
  indices = tf.boolean_mask(indices, indices > 0) - 1
  result = tf.gather(all_quantiles, indices, axis=axis)

  # In the specific case where we want a single quantile, squeezes the
  # quantile dimension.
  can_squeeze = tf.equal(tf.shape(result)[axis], 1)
  if tf.math.logical_and(can_squeeze, may_squeeze):
    result = tf.squeeze(result, axis=axis)
  return result
def _sample_n(self, n, seed=None):
  """Draws `n` samples from the mixture.

  There are two code paths:

  * When `self._use_static_graph` is set, every component is sampled `n`
    times. The draws are stacked and combined with a one-hot mask built from
    the categorical samples.
  * Otherwise, the categorical draws partition the flat (sample, batch) index
    space. Each component `c` is sampled only `n_class` times, where `n_class`
    is the number of positions assigned to it. The results are reassembled
    with `tf.dynamic_stitch`.

  Seeds for the categorical and each component come from
  `samplers.split_seed`. If a component rejects that seed with a `TypeError`
  matching the stateless-seed messages, sampling falls back to a stateful
  `SeedStream` and a warning is emitted.

  Args:
    n: Number of samples to draw (Python int or scalar integer `Tensor`).
    seed: Seed forwarded to `samplers.split_seed` and `SeedStream`.

  Returns:
    Samples of shape `[n] + batch_shape + event_shape`.
  """
  seeds = samplers.split_seed(seed, n=self.num_components + 1, salt='Mixture')
  try:
    seed_stream = SeedStream(seed, salt='Mixture')
  except TypeError as e:  # Can happen for Tensor seed.
    # Stateful fallback is unavailable; re-raise this error if a component
    # later needs it.
    seed_stream = None
    seed_stream_err = e

  if self._use_static_graph:
    # This sampling approach is almost the same as the approach used by
    # `MixtureSameFamily`. The differences are due to having a list of
    # `Distribution` objects rather than a single object, and maintaining
    # random seed management that is consistent with the non-static code
    # path.
    samples = []
    cat_samples = self.cat.sample(n, seed=seeds[0])

    for c in range(self.num_components):
      try:
        samples.append(self.components[c].sample(n, seed=seeds[c + 1]))
        # Advance the stream so a later fallback sees the same seed position
        # it would have if every component had used it.
        if seed_stream is not None:
          seed_stream()
      except TypeError as e:
        # Only fall back for errors that indicate a stateless-seed mismatch.
        if ('Expected int for argument' not in str(e) and
            TENSOR_SEED_MSG_PREFIX not in str(e)):
          raise
        if seed_stream is None:
          raise seed_stream_err
        msg = (
            'Falling back to stateful sampling for `components[{}]` {} of '
            'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
            'This fallback may be removed after 20-Aug-2020. ({})')
        warnings.warn(
            msg.format(c, self.components[c].name, type(self.components[c]),
                       str(e)))
        samples.append(self.components[c].sample(
            n, seed=seed_stream()))
    stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
    x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
    npdt = dtype_util.as_numpy_dtype(x.dtype)
    mask = tf.one_hot(
        indices=cat_samples,  # [n, B]
        depth=self._num_components,  # == k
        on_value=npdt(1),
        off_value=npdt(0))  # [n, B, k]
    mask = distribution_util.pad_mixture_dimensions(
        mask, self, self._cat,
        tensorshape_util.rank(
            self._static_event_shape))  # [n, B, k, [1]*e]
    # Summing over the component axis keeps only the selected component.
    return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

  n = tf.convert_to_tensor(n, name='n')
  static_n = tf.get_static_value(n)
  n = int(static_n) if static_n is not None else n
  cat_samples = self.cat.sample(n, seed=seeds[0])

  # Prefer static (Python) shapes/sizes when fully known, else dynamic ones.
  static_samples_shape = cat_samples.shape
  if tensorshape_util.is_fully_defined(static_samples_shape):
    samples_shape = tensorshape_util.as_list(static_samples_shape)
    samples_size = tensorshape_util.num_elements(static_samples_shape)
  else:
    samples_shape = tf.shape(cat_samples)
    samples_size = tf.size(cat_samples)
  static_batch_shape = self.batch_shape
  if tensorshape_util.is_fully_defined(static_batch_shape):
    batch_shape = tensorshape_util.as_list(static_batch_shape)
    batch_size = tensorshape_util.num_elements(static_batch_shape)
  else:
    batch_shape = tf.shape(cat_samples)[1:]
    batch_size = tf.reduce_prod(batch_shape)
  static_event_shape = self.event_shape
  if tensorshape_util.is_fully_defined(static_event_shape):
    event_shape = np.array(
        tensorshape_util.as_list(static_event_shape), dtype=np.int32)
  else:
    # Resolved from the first component sample inside the loop below.
    event_shape = None

  # Get indices into the raw cat sampling tensor. We will
  # need these to stitch sample values back out after sampling
  # within the component partitions.
  samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape)

  # Partition the raw indices so that we can use
  # dynamic_stitch later to reconstruct the samples from the
  # known partitions.
  partitioned_samples_indices = tf.dynamic_partition(
      data=samples_raw_indices,
      partitions=cat_samples,
      num_partitions=self.num_components)

  # Copy the batch indices n times, as we will need to know
  # these to pull out the appropriate rows within the
  # component partitions.
  batch_raw_indices = tf.reshape(tf.tile(tf.range(0, batch_size), [n]),
                                 samples_shape)

  # Explanation of the dynamic partitioning below:
  #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
  # Suppose partitions are:
  #     [1 1 0 0 1 1]
  # After partitioning, batch indices are cut as:
  #     [batch_indices[x] for x in 2, 3]
  #     [batch_indices[x] for x in 0, 1, 4, 5]
  # i.e.
  #     [1 1] and [0 0 0 0]
  # Now we sample n=2 from part 0 and n=4 from part 1.
  # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
  # and for part 1 we want samples from batch entries 0, 0, 0, 0
  #   (samples 0, 1, 2, 3).
  partitioned_batch_indices = tf.dynamic_partition(
      data=batch_raw_indices,
      partitions=cat_samples,
      num_partitions=self.num_components)
  samples_class = [None for _ in range(self.num_components)]

  for c in range(self.num_components):
    # Number of (sample, batch) positions assigned to component c.
    n_class = tf.size(partitioned_samples_indices[c])
    try:
      samples_class_c = self.components[c].sample(n_class, seed=seeds[c + 1])
      if seed_stream is not None:
        seed_stream()
    except TypeError as e:
      if ('Expected int for argument' not in str(e) and
          TENSOR_SEED_MSG_PREFIX not in str(e)):
        raise
      if seed_stream is None:
        raise seed_stream_err
      msg = (
          'Falling back to stateful sampling for `components[{}]` {} of '
          'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
          'This fallback may be removed after 20-Aug-2020. ({})')
      warnings.warn(
          msg.format(c, self.components[c].name, type(self.components[c]),
                     str(e)))
      samples_class_c = self.components[c].sample(n_class, seed=seed_stream())

    if event_shape is None:
      # Event dims follow the leading sample dim and the batch dims.
      batch_ndims = prefer_static.rank_from_shape(batch_shape)
      event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

    # Pull out the correct batch entries from each index.
    # To do this, we may have to flatten the batch shape.

    # For sample s, batch element b of component c, we get the
    # partitioned batch indices from
    # partitioned_batch_indices[c]; and shift each element by
    # the sample index. The final lookup can be thought of as
    # a matrix gather along locations (s, b) in
    # samples_class_c where the n_class rows correspond to
    # samples within this component and the batch_size columns
    # correspond to batch elements within the component.
    #
    # Thus the lookup index is
    #   lookup[c, i] = batch_size * s[i] + b[c, i]
    # for i = 0 ... n_class[c] - 1.
    lookup_partitioned_batch_indices = (
        batch_size * tf.range(n_class) +
        partitioned_batch_indices[c])
    samples_class_c = tf.reshape(
        samples_class_c,
        tf.concat([[n_class * batch_size], event_shape], 0))
    samples_class_c = tf.gather(
        samples_class_c, lookup_partitioned_batch_indices,
        name='samples_class_c_gather')
    samples_class[c] = samples_class_c

  # Stitch back together the samples across the components.
  lhs_flat_ret = tf.dynamic_stitch(
      indices=partitioned_samples_indices, data=samples_class)
  # Reshape back to proper sample, batch, and event shape.
  ret = tf.reshape(
      lhs_flat_ret, tf.concat([samples_shape, event_shape], 0))
  tensorshape_util.set_shape(
      ret,
      tensorshape_util.concatenate(static_samples_shape, self.event_shape))
  return ret
def reduce_audio_in_batch(tensor, hparams=None, is_training=True):
  """Mixes the per-instrument notes of one example into a single audio clip.

  Each of the first `hparams.timbre_training_max_instruments` notes is handled
  as follows:

  * It is delayed by a random start offset in
    `[0, hparams.timbre_max_start_offset)`.
  * If `hparams.timbre_max_len` is set, the note is truncated to that length
    and its end index is clipped to it.
  * The notes are zero-padded to a common length and averaged.

  The mix is then padded to a fixed length and encoded as WAV data.

  Args:
    tensor: Dict-like example with per-instrument 'pitch', 'audio' (a
      `tf.SparseTensor`, densified here) and 'instrument_family' entries.
    hparams: Hyperparameters. Required despite the `None` default.
    is_training: Unused; kept for interface compatibility.

  Returns:
    A dict with:
      audio: WAV data of the mixed clip as a `tf.string` tensor.
      note_croppings: int32 tensor built from each instrument's
        `NoteCropping(pitch, start_idx, end_idx)`.
      instrument_families: int32 one-hot instrument families.
  """
  instrument_count = hparams.timbre_training_max_instruments
  note_cropping_list = []
  instrument_family_list = []
  samples_list = []
  max_length = 0
  # Densify once outside the loop instead of once per use per instrument.
  dense_audio = tf.sparse.to_dense(tensor['audio'])
  for i in range(instrument_count):
    pitch = tensor['pitch'][i]
    note_audio = dense_audio[i]
    # Move the audio so there are different attack times.
    start_idx = tf.random.uniform((),
                                  minval=0,
                                  maxval=hparams.timbre_max_start_offset,
                                  dtype='int64')
    samples = K.concatenate([tf.zeros(start_idx), note_audio])
    end_idx = (
        start_idx +
        tf.py_function(_get_approx_note_length, [note_audio], tf.int64))
    if hparams.timbre_max_len:
      # Clip the note end and truncate the audio independently. The audio can
      # exceed the limit even when the note ends early, and any clip longer
      # than `timbre_max_len` would make the final padding amount negative.
      if end_idx > hparams.timbre_max_len:
        end_idx = hparams.timbre_max_len
      if len(samples) > hparams.timbre_max_len:
        samples = tf.slice(samples, begin=[0], size=[hparams.timbre_max_len])
    max_length = max(max_length, len(samples))
    samples_list.append(samples)

    instrument_family = tensor['instrument_family'][i]
    note_cropping_list.append(
        timbre_dataset_util.NoteCropping(pitch=pitch,
                                         start_idx=start_idx,
                                         end_idx=end_idx))
    instrument_family_list.append(
        tf.one_hot(tf.cast(instrument_family, tf.int32),
                   hparams.timbre_num_classes))

  # Pad the end of the shorter audio clips.
  samples_list = [
      tf.pad(samples, [[0, max_length - len(samples)]])
      for samples in samples_list
  ]

  combined_samples = (
      tf.reduce_sum(tf.convert_to_tensor(samples_list), axis=0) /
      instrument_count)

  # Ensure all audios in batches are the same length.
  if hparams.timbre_max_len:
    pad_length = hparams.timbre_max_len
  else:
    pad_length = hparams.timbre_max_start_offset + 5 * hparams.sample_rate
  combined_samples = tf.pad(
      combined_samples, [[0, pad_length - tf.shape(combined_samples)[0]]])
  note_croppings = tf.convert_to_tensor(note_cropping_list, dtype=tf.int32)
  instrument_families = tf.convert_to_tensor(instrument_family_list,
                                             dtype=tf.int32)

  wav_data = tf.py_function(
      lambda x: audio_io.samples_to_wav_data(
          x.numpy(), sample_rate=hparams.sample_rate), [combined_samples],
      tf.string)

  return dict(
      audio=wav_data,
      note_croppings=note_croppings,
      instrument_families=instrument_families,
  )
def interpolate(x_values,
                spline_data,
                optimize_for_tpu=False,
                dtype=None,
                name=None):
  """Interpolates spline values for the given `x_values` and the `spline_data`.

  Constant extrapolation is performed for the values outside the domain
  `spline_data.x_data`. This means that for `x > max(spline_data.x_data)`,
  `interpolate(x, spline_data) = spline_data.y_data[-1]`
  and for  `x < min(spline_data.x_data)`,
  `interpolate(x, spline_data) = spline_data.y_data[0]`.

  For the interpolation formula refer to p.548 of [1].

  The normalized position `t` within each interval is computed with
  `tf.math.divide_no_nan`. Zero-width intervals, which always occur for
  out-of-domain points, therefore yield finite gradients instead of NaNs.

  #### References:
  [1]: R. Sedgewick, Algorithms in C, 1990, p. 545-550.
    Link: http://index-of.co.uk/Algorithms/Algorithms%20in%20C.pdf

  Args:
    x_values: A real `Tensor` of shape `batch_shape + [num_points]`.
    spline_data: An instance of `SplineParameters`. `spline_data.x_data` should
      have the same batch shape as `x_values`.
    optimize_for_tpu: A Python bool. If `True`, the algorithm uses one-hot
      encoding to lookup indices of `x_values` in `spline_data.x_data`. This
      significantly improves performance of the algorithm on a TPU device but
      may slow down performance on the CPU.
      Default value: `False`.
    dtype: Optional dtype for `x_values`.
      Default value: `None` which maps to the default dtype inferred by
      TensorFlow.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` which is mapped to the default name
      `cubic_spline_interpolate`.

  Returns:
    A `Tensor` of the same shape and `dtype` as `x_values`. Represents
    the interpolated values.

  Raises:
    ValueError:
      If `x_values` batch shape is different from `spline_data.x_data` batch
      shape.
  """
  name = name or "cubic_spline_interpolate"
  with tf.name_scope(name):
    x_values = tf.convert_to_tensor(x_values, dtype=dtype, name="x_values")
    dtype = x_values.dtype
    # Unpack the spline data
    x_data = spline_data.x_data
    y_data = spline_data.y_data
    spline_coeffs = spline_data.spline_coeffs
    rank = max(x_data.shape.rank, x_values.shape.rank)
    x_data = _expand_to_rank(x_data, rank)
    y_data = _expand_to_rank(y_data, rank)
    x_values = _expand_to_rank(x_values, rank)
    spline_coeffs = _expand_to_rank(spline_coeffs, rank)
    # Try broadcast batch_shapes
    if x_values.shape.as_list()[:-1] != x_data.shape.as_list()[:-1]:
      try:
        x_values = _broadcast_batch_shape(x_values, x_data.shape[:-1])
      except (tf.errors.InvalidArgumentError, ValueError):
        try:
          x_data = _broadcast_batch_shape(x_data, x_values.shape[:-1])
          y_data = _broadcast_batch_shape(y_data, x_values.shape[:-1])
          spline_coeffs = _broadcast_batch_shape(
              spline_coeffs, x_values.shape[:-1])
        except (tf.errors.InvalidArgumentError, ValueError) as e:
          msg = ("Can not broadcast batch shapes {} and {}")
          raise ValueError(
              msg.format(x_values.shape.as_list()[:-1],
                         x_data.shape.as_list()[:-1])) from e

    # Determine the splines to use.
    indices = tf.searchsorted(x_data, x_values, side="right") - 1
    # This selects all elements for the start of the spline interval.
    # Make sure indices lie in the permissible range
    indices_lower = tf.maximum(indices, 0)
    # This selects all elements for the end of the spline interval.
    # Make sure indices lie in the permissible range
    indices_upper = tf.minimum(indices + 1, x_data.shape.as_list()[-1] - 1)
    # Prepare indices for `tf.gather_nd` or `tf.one_hot`
    # TODO(b/156720909): Extract get_slice logic into a common utilities module
    # for cubic and linear interpolation
    if optimize_for_tpu:
      x_data_size = x_data.shape.as_list()[-1]
      lower_encoding = tf.one_hot(indices_lower, x_data_size, dtype=dtype)
      upper_encoding = tf.one_hot(indices_upper, x_data_size, dtype=dtype)
    else:
      index_matrix = _prepare_indices(indices)
      lower_encoding = tf.concat(
          [index_matrix, tf.expand_dims(indices_lower, -1)], -1)
      upper_encoding = tf.concat(
          [index_matrix, tf.expand_dims(indices_upper, -1)], -1)

    # Calculate dx and dy.
    # Simplified logic:
    # dx = x_data[indices + 1] - x_data[indices]
    # dy = y_data[indices + 1] - y_data[indices]
    # indices is a tensor with different values per row/spline
    # Hence use a selection matrix with gather_nd
    def get_slice(x, encoding):
      if optimize_for_tpu:
        return tf.math.reduce_sum(tf.expand_dims(x, axis=-2) * encoding,
                                  axis=-1)
      else:
        return tf.gather_nd(x, encoding)

    x0 = get_slice(x_data, lower_encoding)
    x1 = get_slice(x_data, upper_encoding)
    dx = x1 - x0

    y0 = get_slice(y_data, lower_encoding)
    y1 = get_slice(y_data, upper_encoding)
    dy = y1 - y0

    spline_coeffs0 = get_slice(spline_coeffs, lower_encoding)
    spline_coeffs1 = get_slice(spline_coeffs, upper_encoding)

    # `dx` is zero whenever lower and upper indices coincide (always the case
    # outside the domain). A plain division there produces inf/NaN that
    # `tf.where` hides in the forward pass but not in the gradient, so use a
    # NaN-safe division.
    t = tf.math.divide_no_nan(x_values - x0, dx)
    t = tf.where(dx > 0, t, tf.zeros_like(t))

    df = ((t + 1.0) * spline_coeffs1 * 2.0) - (
        (t - 2.0) * spline_coeffs0 * 2.0)
    df1 = df * t * (t - 1) / 6.0
    result = y0 + (t * dy) + (dx * dx * df1)
    # Use constant extrapolation outside the domain
    upper_bound = tf.expand_dims(
        tf.reduce_max(x_data, -1), -1) + tf.zeros_like(result)
    lower_bound = tf.expand_dims(
        tf.reduce_min(x_data, -1), -1) + tf.zeros_like(result)
    result = tf.where(
        tf.logical_and(x_values <= upper_bound, x_values >= lower_bound),
        result, tf.where(x_values > upper_bound, y0, y1))
    return result
def convert_to_one_hot(self, samples):
  """Snaps `samples` to one-hot vectors at their argmax along the last axis.

  The depth is `self.distribution.event_size` and the output dtype is
  `self._output_dtype`.
  """
  winners = tf.argmax(samples, axis=-1)
  depth = self.distribution.event_size
  return tf.one_hot(winners, depth, dtype=self._output_dtype)
def map_fn(image, label):
  """Applies `preprocess_fn_finetune` to `image`; one-hot encodes `label` to `num_classes`."""
  return preprocess_fn_finetune(image), tf.one_hot(label, num_classes)