def sample_categorical(x, dim=None):
  """Samples from a categorical distribution over dim.

  Assumes x holds probabilities that sum to 1 along dim.
  """
  dim = x.shape[-1] if dim is None else dim
  # Cumulative distribution along the sampled dimension.
  cdf = mtf.cumsum(x, dim)
  # One uniform draw per distribution (dim is dropped from the shape).
  rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
  # argmax over the 0/1 mask returns the first index where cdf exceeds the draw.
  mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
  return mtf.argmax(mask, dim)
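# The cumsum/argmax combination above is inverse-CDF sampling.  A minimal
# NumPy sketch of the same trick (sample_categorical_np is a hypothetical
# name, used here only for illustration):
import numpy as np

def sample_categorical_np(probs, rng=np.random):
  cdf = np.cumsum(probs, axis=-1)                # monotone, ends at ~1.0
  u = rng.uniform(size=probs.shape[:-1] + (1,))  # one draw per distribution
  mask = (cdf > u).astype(np.int32)              # looks like 0...0 1...1
  return np.argmax(mask, axis=-1)                # first index where cdf > u

# E.g. sample_categorical_np(np.array([0.1, 0.2, 0.7])) returns 2 about 70%
# of the time, because the draw lands in [0.3, 1.0) with probability 0.7.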
def _noisy_targets(self, targets, losses=None):
  """Generate noisy targets for denoising models.

  Args:
    targets: a Tensor
    losses: an optional list onto which to append training losses

  Returns:
    a Tensor the same dtype and shape as targets
  """
  hparams = self._hparams
  if hparams.mode == tf.estimator.ModeKeys.TRAIN:
    nt_train = self._noisy_targets_from_spec(
        targets, hparams.noising_spec_train, losses=losses)
    if hparams.noising_use_eval_during_train > 0:
      nt_eval = self._noisy_targets_from_spec(
          targets, hparams.noising_spec_eval)
      use_eval_noising = mtf.less(
          mtf.random_uniform(targets.mesh, targets.shape - self.length_dim),
          hparams.noising_use_eval_during_train)
      nt_train = mtf.where(use_eval_noising, nt_eval, nt_train)
    return nt_train
  else:
    return self._noisy_targets_from_spec(targets, hparams.noising_spec_eval)
def benchmark_model(mesh):
  """Initializes a 3D volume with random noise and executes a forward FFT."""
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  x_dim = mtf.Dimension("nx", FLAGS.cube_size)
  y_dim = mtf.Dimension("ny", FLAGS.cube_size)
  z_dim = mtf.Dimension("nz", FLAGS.cube_size)

  tx_dim = mtf.Dimension("tnx", FLAGS.cube_size)
  ty_dim = mtf.Dimension("tny", FLAGS.cube_size)
  tz_dim = mtf.Dimension("tnz", FLAGS.cube_size)

  # Create field
  field = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim])

  # Apply FFT
  fft_field = mpm.fft3d(mtf.cast(field, tf.complex64),
                        [tx_dim, ty_dim, tz_dim])

  # Inverse FFT
  rfield = mtf.cast(mpm.ifft3d(fft_field, [x_dim, y_dim, z_dim]), tf.float32)

  # Compute errors
  err = mtf.reduce_max(mtf.abs(field - rfield))
  return err
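# A minimal sketch of how benchmark_model might be lowered and run on a
# single device with mtf's placement mesh implementation.  The mesh-shape and
# layout strings are illustrative assumptions, and FLAGS is assumed to be
# defined as in the surrounding benchmark script.
def run_benchmark():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")
  err = benchmark_model(mesh)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mtf.convert_to_shape("all:1"),
      mtf.convert_to_layout_rules("batch:all"),
      [""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_err = lowering.export_to_tf_tensor(err)
  with tf.Session() as sess:
    print("max |field - ifft3d(fft3d(field))| =", sess.run(tf_err))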
def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
  if noising_spec["type"] == "mask":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
    return targets * mtf.cast(
        mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                    noising_spec["prob"]), targets.dtype)
  elif noising_spec["type"] == "random_zipfian":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens.
    # Rather than drawing the replacement tokens uniformly, we sample from
    # a distribution favoring lower token-ids, assuming that the ids have
    # been assigned in frequency order.  The probability of choosing an
    # id is proportional to 1/(id+10).
    logits = mtf.log(1.0 / (mtf.range(
        targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
    logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
    r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
    use_noise = mtf.less(
        mtf.random_uniform(targets.mesh, targets.shape),
        noising_spec["prob"])
    return mtf.where(use_noise, r, targets)
  elif noising_spec["type"] == "transformer":
    # Train a small transformer to fill in masked out values, then
    # sample from it.
    hparams = self._hparams
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      raise NotImplementedError("Not implemented")
    noiser_hparams = copy.copy(self._hparams)
    noiser_hparams.del_hparam("mode")
    noiser_hparams.override_from_dict(noising_spec["overrides"])
    with tf.variable_scope("noiser"):
      noiser = MtfTransformer(
          noiser_hparams,
          mode=hparams.mode,
          problem_hparams=self._problem_hparams)
      logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
          self._original_features, targets.mesh)
      samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
      losses.append(loss)
      return samples
  else:
    raise ValueError("unknown noising spec %s" % noising_spec)
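# For reference, the noising specs consumed above are plain dicts.  A hedged
# example of what they might look like (the values are illustrative, not
# taken from any particular config):
#
#   hparams.noising_spec_train = {"type": "mask", "prob": 0.15}
#   hparams.noising_spec_eval = {"type": "random_zipfian", "prob": 0.1}
#   # "transformer" additionally carries hparam overrides for the noiser:
#   # {"type": "transformer", "overrides": {...}}
#
# The "random_zipfian" branch samples replacement ids with probability
# proportional to 1/(id + 10); e.g. for a 4-token vocab the normalized
# probabilities are (1/10, 1/11, 1/12, 1/13) / (1/10 + 1/11 + 1/12 + 1/13)
# ~= (0.285, 0.259, 0.237, 0.219), mildly favoring low ids.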
def add_position_timing_signal_func(self, context, x, step):
  """Add n-dimensional embedding as the position (horizontal) timing signal.

  Args:
    context: mtf context
    x: a tensor with shape [batch, length, depth]
    step: step

  Returns:
    a Tensor with the same shape as x.
  """
  if not self.position_start_index:
    index = 0
  elif self.position_start_index == "random":
    # Shift all positions randomly
    # TODO(dehghani): What would be a reasonable maximum shift?
    index = mtf.random_uniform(
        context.mesh, [], maxval=x.shape.dims[1].size, dtype=tf.int32)
  elif self.position_start_index == "step":
    # Shift positions based on the step
    if self.recurrence_type == "act":
      num_steps = self.act_max_steps
    else:
      num_steps = self.num_rec_steps
    index = mtf.cast(
        x.shape.dims[1].size * step / num_steps, dtype=tf.int32)

  length = context.length_dim
  channels = context.model.model_dim
  signal = self.get_timing_signal_1d(
      context, length, channels, start_index=index)

  if self.add_or_concat_timing_signal == "add":
    x_with_timing = x + mtf.cast(signal, x.dtype)
  # Unimplemented
  if self.add_or_concat_timing_signal == "concat":
    batch_dim = x.shape.dims[0]
    out_shape = mtf.Shape([batch_dim] + signal.shape.dims[1:])
    signal_tiled = mtf.broadcast(signal, out_shape)
    x_with_timing = mtf.concat(
        (x, signal_tiled),
        concat_dim_name=signal_tiled.dimension_names[-1])

  return x_with_timing
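# For reference, get_timing_signal_1d computes the standard sinusoidal
# position embedding.  A NumPy sketch of that computation (the helper below
# is a hypothetical stand-in for illustration, not the method used above):
import numpy as np

def timing_signal_1d_np(length, channels, min_timescale=1.0,
                        max_timescale=1.0e4, start_index=0):
  # Positions are shifted by start_index, which is how the "random" and
  # "step" start-index modes above move the signal along the sequence.
  position = np.arange(length, dtype=np.float32) + start_index
  num_timescales = channels // 2
  log_timescale_increment = (
      np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1))
  inv_timescales = min_timescale * np.exp(
      np.arange(num_timescales, dtype=np.float32) * -log_timescale_increment)
  scaled_time = position[:, None] * inv_timescales[None, :]
  # [length, channels]: sines in the first half, cosines in the second.
  return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)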
def model_fn(nc=64, batch_size=1):
  """Example of a function implementing a CNN and returning a value."""
  # Create the mesh TF graph
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")

  # Define the named dimensions
  n_block_x = 4
  n_block_y = 2
  n_block_z = 1

  batch_dim = mtf.Dimension("batch", batch_size)
  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  image_c_dim = mtf.Dimension('image_c', 3)
  hidden_dim = mtf.Dimension('h', 128)

  # Create some input data
  data = mtf.random_uniform(mesh, [batch_dim,
                                   nx_dim, ny_dim, nz_dim,
                                   sx_dim, sy_dim, sz_dim,
                                   image_c_dim])

  net = mtf.layers.conv3d_with_blocks(
      data, hidden_dim,
      filter_size=(3, 3, 3), strides=(1, 1, 1), padding='SAME',
      d_blocks_dim=nx_dim, h_blocks_dim=ny_dim)

  net = mtf.reduce_sum(net, output_shape=[batch_dim, hidden_dim])
  return net
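# With the defaults above (nc=64, batch_size=1), the input volume is carved
# into 4 x 2 x 1 blocks of 16 x 32 x 64 voxels, so `data` has shape
# [1, 4, 2, 1, 16, 32, 64, 3]; each blocked axis satisfies
# n_block * s_block == nc (4*16 == 2*32 == 1*64 == 64).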
def _top_2_gating(inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
                  hparams, train, variable_dtype, importance=None,
                  name="top_2_gating"):
  """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py.  Once this code moves out of "research", we should
  pass the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold: a float

  The returned forward assignment is a tensor used to map (via einsum) from
  the inputs to the expert_inputs.  Likewise, the returned combine_tensor is
  used to map (via einsum) from the expert outputs to the outputs.  Both the
  forward and backward assignments are mostly zeros.  The shapes of the
  tensors are as follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs: [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]
  expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector.  If the importance of an input is 1.0, then we send it to
  up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
  one expert.  If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts.  Inputs to the first-choice expert-group get
  importance 1.0.  Inputs to the second-choice expert group get importance
  0.5.  Inputs that represent padding get importance 0.0.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    variable_dtype: a mtf.VariableDType
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]
    name: an optional string

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """
  group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

  raw_gates = mtf.layers.dense(
      inputs, experts_dim, use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype, name=name)
  raw_gates = mtf.softmax(raw_gates, experts_dim)

  # The internals of this function run in float32.
  # bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  expert_capacity_f = float(expert_capacity_dim.size)

  # FIND TOP 2 EXPERTS PER POSITION
  # Find the top expert for each position.  shape=[batch, group]
  index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
  density_1_proxy = raw_gates
  if importance is not None:
    mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
  gates_without_top_1 = raw_gates * (1.0 - mask_1)
  # [batch, group]
  index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
  if importance is not None:
    mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                    reduced_dim=group_size_dim)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))

  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as
    # the second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top
    # one has been removed.
    normalized = gates_without_top_1 / (mtf.reduce_sum(
        gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  if train:
    policy = hparams.moe_second_policy_train
    threshold = hparams.moe_second_threshold_train
  else:
    policy = hparams.moe_second_policy_eval
    threshold = hparams.moe_second_threshold_eval
  if policy == "all":
    # Use second-place experts for all examples.
    pass
  elif policy == "none":
    # Never use second-place experts.
    mask_2 = mtf.zeros_like(mask_2)
  elif policy == "threshold":
    # Use second-place experts if gate_2 > threshold.
    mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
  elif policy == "random":
    # Use second-place experts with probability min(1.0, gate_2 / threshold).
    mask_2 *= mtf.to_float(
        mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                 gate_2 / max(threshold, 1e-9)))
  else:
    raise ValueError("Unknown policy %s" % policy)

  # COMPUTE ASSIGNMENT TO EXPERTS
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(
      mask_1, group_size_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit.  [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # [batch, group, experts]
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  gate_2 *= mask_2_flat
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)

  # [batch, group, experts, expert_capacity]
  combine_tensor = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim)
      + gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)

  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
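# The capacity bookkeeping above is the subtle part of top-2 gating.  A small
# NumPy illustration with hypothetical numbers (one group of 5 tokens,
# 2 experts, capacity 2 per expert): an exclusive cumsum over the top-1 mask
# gives each token its slot within its expert's buffer, and tokens whose slot
# index reaches the capacity are dropped from the dispatch mask.
import numpy as np

mask_1 = np.array([[1, 0],   # token 0 -> expert 0
                   [1, 0],   # token 1 -> expert 0
                   [0, 1],   # token 2 -> expert 1
                   [1, 0],   # token 3 -> expert 0 (third token for expert 0)
                   [0, 1]], dtype=np.float32)
expert_capacity = 2.0

# Exclusive cumsum along the group axis counts earlier tokens routed to the
# same expert; multiplying by the mask keeps slots only where tokens landed.
position_in_expert = (np.cumsum(mask_1, axis=0) - mask_1) * mask_1
mask_1 *= (position_in_expert < expert_capacity)
assert mask_1[3, 0] == 0.0  # token 3 overflowed expert 0 and is dropped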
def ut_function(state, step, halting_probability, remainders, n_updates,
                previous_state):
  """Implements ACT (position-wise halting).

  Note: `context`, `threshold` and `mask` are closed over from the
  enclosing layer.

  Args:
    state: 3-D Tensor: [batch_size, length, channel]
    step: indicates number of steps taken so far
    halting_probability: halting probability
    remainders: act remainders
    n_updates: act n_updates
    previous_state: previous state

  Returns:
    transformed_state: transformed state
    step: step+1
    halting_probability: halting probability
    remainders: act remainders
    n_updates: act n_updates
    new_state: new state
  """
  state = self.step_preprocess(context, state, step)

  if self.act_type == "random":
    # random as halting probability
    p = mtf.random_uniform(
        context.mesh,
        shape=halting_probability.shape.dims,
        dtype=context.variable_dtype)
  else:
    last_dim_name = state.shape.dimension_names[-1]
    new_dims = [mtf.Dimension(last_dim_name, 1)]
    with tf.variable_scope(
        "sigmoid_activation_for_pondering", reuse=tf.AUTO_REUSE):
      p = mtf.layers.dense(
          state,
          variable_dtype=context.variable_dtype,
          reduced_dims=[state.shape.dims[-1]],
          new_dims=new_dims,
          activation=mtf.sigmoid,
          use_bias=True)
      if self.act_type == "global":
        # average over all positions (as a global halting prob)
        p = mtf.reduce_mean(p, reduced_dim=p.shape.dims[1])
        p = mtf.squeeze(p)
      else:
        # maintain position-wise probabilities
        new_shape = p.shape.dims[:-1]
        p = mtf.reshape(p, new_shape)

  # Mask for inputs which have not halted yet
  still_running = mtf.cast(
      mtf.less(halting_probability, 1.0), context.activation_dtype)

  # Mask of inputs which halted at this step
  new_halted = mtf.cast(
      mtf.greater(halting_probability + p * still_running, threshold),
      context.activation_dtype) * still_running

  # Mask of inputs which haven't halted, and didn't halt this step
  still_running = mtf.cast(
      mtf.less_equal(halting_probability + p * still_running, threshold),
      context.activation_dtype) * still_running

  # Add the halting probability for this step to the halting
  # probabilities for those inputs which haven't halted yet
  halting_probability += p * still_running

  # Compute remainders for the inputs which halted at this step
  remainders += new_halted * (1 - halting_probability)

  # Add the remainders to those inputs which halted at this step
  halting_probability += new_halted * remainders

  # Increment n_updates for all inputs which are still running
  n_updates += still_running + new_halted

  # Compute the weight to be applied to the new state and output:
  #   0 when the input has already halted,
  #   p when the input hasn't halted yet,
  #   the remainder when it halted this step.
  update_weights = p * still_running + new_halted * remainders

  # apply transformation on the state
  transformed_state = state
  for _ in range(self.num_inrecurrence_layers):
    transformed_state = self.vanilla_transformer_layer(
        context, transformed_state, mask)

  # update running part in the weighted state and keep the rest
  new_state = ((transformed_state * update_weights) +
               (previous_state * (1 - update_weights)))

  if self.act_type == "accumulated":
    # Add in the weighted state
    new_state = (transformed_state * update_weights) + previous_state

  step += 1

  return (transformed_state, step, halting_probability, remainders,
          n_updates, new_state)
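# A scalar trace of the ACT bookkeeping above for a single position, with
# threshold 0.99 and per-step halting probabilities p = 0.3, 0.5, 0.4
# (illustrative numbers):
#   step 1: 0.3 < 0.99             -> still running, halting_probability = 0.3
#   step 2: 0.3 + 0.5 = 0.8 < 0.99 -> still running, halting_probability = 0.8
#   step 3: 0.8 + 0.4 = 1.2 > 0.99 -> halts; remainder = 1 - 0.8 = 0.2,
#           n_updates = 3, and the final update weight is the remainder 0.2
#           rather than p, so the per-step weights (0.3, 0.5, 0.2) sum to 1.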