def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
    """Compute an entmax-style sparse normalization of ``x`` by bisection.

    A threshold ``tau`` is searched for by repeatedly halving the bracket
    ``[tau_lo, tau_hi]`` so that ``sum(relu(x - tau) ** (1 / (alpha - 1)))``
    along ``dim`` approaches 1. The final candidate is renormalized so it sums
    to exactly one along ``dim``.

    Args:
        x: input tensor (scaled internally by ``alpha - 1``).
        alpha: exponent parameter; must satisfy ``1 < alpha < 2``.
        dim: dimension to normalize over. Defaults to ``x.shape[-1]``.
        n_iter: number of bisection steps. With ``n_iter == 0`` the result is
            the normalized candidate at the lower bracket.

    Returns:
        A tensor with the same dimensions as ``x`` whose entries sum to one
        along ``dim``.

    Raises:
        AssertionError: if ``alpha`` is not strictly between 1 and 2.
    """
    assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'

    def _gp(value, alpha):
        return value ** (alpha - 1)

    def _gp_inv(value, alpha):
        return mtf.pow(value, (1 / (alpha - 1)))

    def _p(value, alpha):
        return _gp_inv(mtf.relu(value), alpha)

    dim = x.shape[-1] if dim is None else dim
    d = dim.size

    x = x * (alpha - 1)
    max_val = mtf.reduce_max(x, reduced_dim=dim)

    # Bracket for the threshold: at tau_lo the mass is >= 1, at tau_hi <= 1.
    tau_lo = max_val - _gp(1, alpha)
    tau_hi = max_val - _gp(1 / d, alpha)

    # Seed p_m with the candidate at tau_lo. This is the same tensor the
    # original used to compute f_lo, so no extra ops are created, and it keeps
    # p_m defined when n_iter == 0.
    p_m = _p(x - tau_lo, alpha)
    f_lo = mtf.reduce_sum(p_m, reduced_dim=dim) - 1

    dm = tau_hi - tau_lo
    for _ in range(n_iter):
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = _p(x - tau_m, alpha)
        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1
        # Move the lower bound up wherever the midpoint has the same sign as
        # the lower end of the bracket.
        mask = mtf.greater_equal((f_m * f_lo), 0)
        tau_lo = mtf.where(mask, tau_m, tau_lo)

    # Bisection is approximate; renormalize so the output sums to one.
    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
    return p_m
def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha=1.3, dim=None, n_iter=50):
    """Gradient of the entmax forward pass with respect to its input.

    Args:
        explicit_inputs: one-element sequence holding the forward input ``x``.
        all_inputs: unused; part of the custom-gradient callback signature.
        forward_operations: unused; part of the callback signature.
        outputs: one-element sequence holding the forward output ``y``.
        output_grads: one-element sequence holding the gradient w.r.t. ``y``.
        alpha: the exponent parameter used in the forward pass.
        dim: dimension that was normalized over. Defaults to ``x.shape[-1]``,
            matching the default used by ``entmax_forward``.
        n_iter: unused; kept so the signature mirrors the forward pass.

    Returns:
        A one-element tuple containing the gradient w.r.t. ``x``.
    """
    x, = explicit_inputs
    y, = outputs
    dY, = output_grads

    # Resolve the default exactly as the forward pass does. Passing
    # reduced_dim=None to mtf.reduce_sum would reduce over every dimension and
    # yield a single scalar correction instead of one per normalized slice.
    dim = x.shape[-1] if dim is None else dim

    # y ** (2 - alpha) on the support of y, zero elsewhere.
    gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y))
    dX = dY * gppr

    # Subtract the gppr-weighted mean so the gradient sums to zero along dim.
    q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim)
    dX = dX - q * gppr

    return dX,
def clip_by_global_norm(grads, clip_norm):
    """Scale gradients so their joint L2 norm does not exceed ``clip_norm``.

    Args:
        grads: list of gradient tensors; ``None`` entries are passed through.
        clip_norm: maximum allowed global norm.

    Returns:
        A pair ``(clipped_grads, global_norm)`` where ``clipped_grads`` has the
        same length and ``None`` positions as ``grads``.
    """
    squared_sums = [
        mtf.reduce_sum(mtf.square(grad)) for grad in grads if grad is not None
    ]
    global_norm = mtf.sqrt(mtf.add_n(squared_sums))

    # The factor is 1 while the norm is within bounds, and
    # clip_norm / global_norm otherwise.
    scale = clip_norm / mtf.maximum(global_norm, clip_norm)

    clipped_grads = []
    for grad in grads:
        if grad is None:
            clipped_grads.append(None)
        else:
            clipped_grads.append(grad * scale)
    return clipped_grads, global_norm
def widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16=None):
    """Build the combined wide + deep graph on top of id/weight inputs.

    Args:
        id_hldr: tensor of feature ids used to index both embedding tables.
        wt_hldr: tensor of feature weights; reshaped to add a trailing
            size-1 dimension and passed to ``wide``/``deep`` as ``mask``.
        vocab_dim: vocabulary dimension of the embedding tables.
        embed_dim: embedding dimension of the deep table.
        outdim: dimension the summed output is reshaped to and then reduced.
        float16: optional variable dtype; when falsy, a float32
            ``mtf.VariableDType`` is used instead.

    Returns:
        A pair ``(result, embedding_table)`` where ``embedding_table`` is the
        deep embedding weight tensor.
    """
    for holder in (id_hldr, wt_hldr):
        logger.debug("[input tensor] (name,shape):({},{})".format(
            holder.name, holder.shape))

    if float16:
        variable_dtype = float16
    else:
        variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    # Deep branch: explicit table + gather so the table can be returned.
    embedding_table = mtf.layers.embedding_weights(
        id_hldr.mesh, vocab_dim, embed_dim, variable_dtype, "deep_embedding")
    deep_output = mtf.gather(embedding_table, id_hldr, vocab_dim)
    logger.debug("[output tensor] (name,shape):({},{})".format(
        deep_output.name, deep_output.shape))

    expend_dim = mtf.Dimension('expend', size=1)
    embed_dim_one = mtf.Dimension('embed_dim_one', size=1)

    weight_dims = wt_hldr.shape.dims
    mask = mtf.reshape(
        wt_hldr,
        new_shape=[weight_dims[0], weight_dims[1], expend_dim],
        name='mask_reshape')
    logger.debug("[output tensor] (name,shape):({},{})".format(
        mask.name, mask.shape))

    # Wide branch: a size-1 embedding per id.
    wide_output = mtf.layers.embedding(
        id_hldr,
        vocab_dim=vocab_dim,
        output_dim=embed_dim_one,
        variable_dtype=variable_dtype,
        name="wide_embedding")
    logger.debug("[output tensor] (name,shape):({},{})".format(
        wide_output.name, wide_output.shape))

    wide_output = wide(wide_output, mask=mask, float16=float16)
    deep_output = deep(deep_output, mask=mask, float16=float16)

    combined = mtf.add(wide_output, deep_output)
    combined = mtf.reshape(
        combined,
        new_shape=[wide_output.shape.dims[0], outdim],
        name='result_reshape')
    result = mtf.reduce_sum(combined, reduced_dim=outdim)
    logger.debug("[output tensor] (name,shape):({},{})".format(
        result.name, result.shape))
    return result, embedding_table
def toy_model(features, mesh):
    """Build a two-layer, bias-free toy autoencoder with Mesh TensorFlow.

    Args:
        features: TF tensor imported with shape ``[batch, io]``.
        mesh: the ``mtf.Mesh`` to build on.

    Returns:
        A pair ``(y, loss)``: the reconstruction and the summed squared error
        between the reconstruction and the input.
    """
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
    hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
    io_dim = mtf.Dimension('io', FLAGS.io_size)

    input_shape = mtf.Shape([batch_dim, io_dim])
    inputs = mtf.import_tf_tensor(mesh, features, input_shape)

    hidden = mtf.layers.dense(inputs, hidden_dim, name='layer1', use_bias=False)
    outputs = mtf.layers.dense(hidden, io_dim, name='layer2', use_bias=False)

    residual = outputs - inputs
    loss = mtf.reduce_sum(mtf.square(residual))
    return outputs, loss
def get_masked_lm_output(self, positions, label_ids, label_weights):
    """Get loss and logits for the masked LM.

    Args:
        positions: indices gathered from the sequence output along
            ``self.seq_dim``.
        label_ids: target ids passed to the softmax cross-entropy.
        label_weights: per-prediction weights; per the comment below, 1.0 for
            real predictions and 0.0 for padding.

    Returns:
        A tuple ``(loss, per_example_loss, logits)`` where ``loss`` is the
        ``label_weights``-weighted mean of ``per_example_loss``.
    """
    input_tensor = self.get_sequence_output()
    output_weights = self.get_embedding_table()

    # [batch_size, num_position, hidden]
    input_tensor = mtf.gather(input_tensor, positions, self.seq_dim)
    with tf.variable_scope("cls/predictions"):
        # We apply one more non-linear transformation before the output layer.
        # This matrix is not used after pre-training.
        with tf.variable_scope("transform"):
            input_tensor = mtf.layers.dense(
                input_tensor,
                reduced_dims=[self.model_dim],
                new_dims=[self.model_dim],
                activation=get_activation(self.config.feedforward_intermediate_act),
                kernel_initializer=self.dense_initializer,
                use_bias=self.config.use_bias)
            input_tensor = self.normalize(input_tensor)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        output_bias = mtf.get_variable(
            input_tensor.mesh,
            name="output_bias",
            shape=[self.vocab_dim],
            initializer=tf.zeros_initializer())
        logits = mtf.einsum([input_tensor, output_weights],
                            reduced_dims=[self.model_dim]) + output_bias
        per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, label_ids, self.vocab_dim, z_loss=1e-4)
        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        numerator = mtf.reduce_sum(label_weights * per_example_loss)
        # The small constant keeps the division finite when every weight is 0.
        denominator = mtf.reduce_sum(label_weights) + mtf.constant(
            input_tensor.mesh, 1e-5, dtype=tf.float32)
        loss = numerator / denominator
    return (loss, per_example_loss, logits)
def compute_loss(self, decoder, hidden, targets, context):
    """Return the padding-masked training loss for ``hidden`` vs ``targets``.

    Targets are shifted by ``self._start_token_id`` and one-hot encoded over
    ``self._vocab_dim``. Positions selected by ``mtf.layers.weights_nonzero``
    on the raw targets contribute to the sum, which is divided by
    ``decoder.loss_denominator``.
    """
    logits = self._embedding.hidden_to_logits(hidden, context=context)

    shifted_targets = targets - self._start_token_id
    soft_targets = mtf.one_hot(
        shifted_targets, self._vocab_dim, dtype=context.activation_dtype)

    per_position_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self._vocab_dim, z_loss=decoder.z_loss)

    padding_mask = mtf.layers.weights_nonzero(
        targets, dtype=context.activation_dtype)

    masked_total = mtf.reduce_sum(per_position_loss * padding_mask)
    denominator = decoder.loss_denominator(targets, context.num_microbatches)
    return masked_total / denominator
def cond_fn(position, ids, *unused_states):
    """Return a boolean tensor that is True while decoding should continue.

    Decoding is finished once every sequence is either past the end of
    ``length_dim`` (or past ``max_steps`` since ``initial_position``, when
    ``max_steps`` is set) or, if ``stop_at_token`` is given, contains more
    stop tokens than ``partial_sequences_eos_count``.
    """
    is_done = mtf.greater_equal(position, length_dim.size)
    if max_steps:
        steps_taken = position - initial_position
        is_done = mtf.logical_or(
            is_done, mtf.greater_equal(steps_taken, max_steps))

    if stop_at_token is not None:
        is_stop_token = mtf.to_int32(mtf.equal(ids, stop_at_token))
        eos_count = mtf.reduce_sum(is_stop_token, reduced_dim=length_dim)
        has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
        is_done = mtf.logical_or(is_done, has_additional_eos)

    return mtf.logical_not(mtf.reduce_all(is_done))
def compute_loss(self, decoder: transformer.Unitransformer, hidden: mtf.Tensor, targets: mtf.Tensor, context: transformer.Context) -> mtf.Tensor:
    """Returns the loss without computing a softmax over the entire vocab.

    Each tail cluster contributes a loss over the positions whose target falls
    in that cluster. The head cluster then contributes a loss in which every
    tail target is replaced by ``end_token_id + cluster_index``.
    """
    total_loss = 0
    tail_masks = []

    for cluster in self._tail_clusters:
        in_cluster = cluster.get_cluster_mask(targets)
        tail_masks.append(in_cluster)

        if cluster.length_projection_factor != 1:
            # TODO(mmatena): Unfold the batch dim to get a super long sequence dim
            # to reduce the risk of overflowing the projection.
            projection = cluster.get_project_to_cluster_length(
                in_cluster, dtype=targets.dtype)
            cluster_targets = mtf.einsum(
                [projection, targets],
                reduced_dims=[targets.shape.get_dim_by_name("length")])
            cluster_hidden = mtf.einsum(
                [mtf.cast(projection, hidden.dtype), hidden],
                reduced_dims=[hidden.shape.get_dim_by_name("length")])
        else:
            cluster_targets = mtf.where(in_cluster, targets, 0)
            cluster_hidden = mtf.where(in_cluster, hidden, 0)

        total_loss += cluster.compute_loss(decoder, cluster_hidden,
                                           cluster_targets, context)

    tail_clusters_dim = mtf.Dimension("tail_clusters", len(tail_masks))

    # One slice per tail cluster, holding that cluster's head-node id wherever
    # the target belongs to the cluster and zero elsewhere.
    node_id_slices = []
    for index, in_cluster in enumerate(tail_masks):
        node_id = self._head_cluster.end_token_id + index
        node_id_slices.append(node_id * mtf.cast(in_cluster, targets.dtype))

    stacked = mtf.stack(node_id_slices, tail_clusters_dim.name)
    tail_node_targets = mtf.reduce_sum(stacked, reduced_dim=tail_clusters_dim)

    is_tail_target = mtf.cast(tail_node_targets, tf.bool)
    head_targets = mtf.where(is_tail_target, tail_node_targets, targets)

    total_loss += self._head_cluster.compute_loss(decoder, hidden,
                                                  head_targets, context)
    return total_loss
def model_fn(nc=64, batch_size=1):
    """Example of a function implementing a CNN and returning a value.

    Builds a fresh Mesh TensorFlow graph, creates a random blocked volume and
    applies one blocked 3D convolution followed by a sum over every dimension
    except batch and hidden.

    Args:
        nc: edge length of the cubic volume; split into 4 x 2 x 1 blocks.
        batch_size: size of the batch dimension.

    Returns:
        The reduced ``mtf`` tensor of shape ``[batch, h]``.
    """
    # Create the mesh TF graph
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    n_block_x = 4
    n_block_y = 2
    n_block_z = 1
    # The original line carried a stray backtick after batch_size, which made
    # the whole definition a syntax error.
    batch_dim = mtf.Dimension("batch", batch_size)
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)
    sx_dim = mtf.Dimension('sx_block', nc//n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc//n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc//n_block_z)
    image_c_dim = mtf.Dimension('image_c', 3)
    hidden_dim = mtf.Dimension('h', 128)

    # Create some input data
    data = mtf.random_uniform(mesh, [batch_dim, nx_dim, ny_dim, nz_dim,
                                     sx_dim, sy_dim, sz_dim,
                                     image_c_dim])

    net = mtf.layers.conv3d_with_blocks(data, hidden_dim,
                                        filter_size=(3, 3, 3), strides=(1, 1, 1),
                                        padding='SAME',
                                        d_blocks_dim=nx_dim, h_blocks_dim=ny_dim)

    net = mtf.reduce_sum(net, output_shape=[batch_dim, hidden_dim])

    return net
def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
    """Cross-entropy between ``targets`` and ``entmax(logits)``.

    Args:
        logits: tensor whose dimensions include ``vocab_dim``.
        targets: either integer ids with the dimensions of ``logits`` minus
            ``vocab_dim`` (hard targets, one-hot encoded here), or a tensor
            with the same dimensions as ``logits`` (soft targets).
        vocab_dim: the dimension to normalize and reduce over.
        z_loss: accepted for signature compatibility; currently unused.

    Returns:
        ``-sum(log(entmax(logits)) * targets)`` reduced over ``vocab_dim``.

    Raises:
        ValueError: if ``vocab_dim`` is not a dimension of ``logits`` or the
            dimensions of ``targets`` do not match the expected set.
    """
    # Validate vocab_dim before it is used in the set arithmetic below, so a
    # wrong vocab_dim is reported as such instead of as a targets mismatch.
    if vocab_dim not in logits.shape.dims:
        raise ValueError("vocab_dim must be in logits.shape.dims")

    if targets.dtype.is_integer:
        # hard targets
        if (set(targets.shape.dims)
                != set(logits.shape.dims).difference([vocab_dim])):
            raise ValueError(
                "entmax_cross_entropy_with_logits with hard targets "
                "dims in targets=%s should be dims in logits=%s other than "
                "vocab_dim=%s" % (targets, logits, vocab_dim))
        targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
    elif set(targets.shape.dims) != set(logits.shape.dims):
        raise ValueError(
            "entmax_cross_entropy_with_logits with soft targets "
            "dims in targets=%s should be dims in logits=%s" % (targets, logits))

    # NOTE(review): if `entmax` yields exact zeros, log() gives -inf and the
    # product with a zero target is NaN -- confirm this is handled upstream.
    log_entmax = mtf.log(entmax(logits, dim=vocab_dim))

    loss = mtf.negative(
        mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))

    return loss
def _switch_gating(inputs, outer_expert_dims, experts_dim, expert_capacity_dim, hparams, train, variable_dtype, importance=None, name="switch_gating", num_microbatches=None):
    """Compute a switch top-1 gating with no-token-left behind behavior.

    Args:
        inputs: tensor to route; ``inputs.shape[0]``, ``inputs.shape[1]`` and
            ``inputs.shape[-2]`` are used as the outer-batch, batch and
            group-size dimensions.
        outer_expert_dims: passed to the gating dense layer as ``expert_dims``.
        experts_dim: dimension enumerating the experts.
        expert_capacity_dim: dimension whose size is the per-expert capacity.
        hparams: reads ``moe_rand_1_policy_train`` / ``moe_rand_1_policy_eval``,
            ``moe_rand_1_jitter`` and ``moe_switch_top_k``.
        train: whether the training policy, jitter and summaries are used.
        variable_dtype: dtype spec for the gating dense layer.
        importance: optional tensor; routing terms are kept only where it
            equals 1.0.
        name: name of the gating dense layer and prefix of the entropy summary.
        num_microbatches: if greater than 1, the load-balancing loss is divided
            by it.

    Returns:
        A tuple ``(dispatch_tensor, combine_tensor, loss)``; ``dispatch_tensor``
        is the nonzero indicator of ``combine_tensor``, and all three are cast
        to ``inputs.dtype``.
    """
    # SELECT EXPERT
    if train:
        policy = hparams.moe_rand_1_policy_train
    else:
        policy = hparams.moe_rand_1_policy_eval

    # Input perturbations
    if train and policy == "input_jitter":
        inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)

    gate_logits = mtf.layers.dense(
        inputs,
        experts_dim,
        use_bias=False,
        expert_dims=outer_expert_dims,
        variable_dtype=variable_dtype,
        name=name)
    raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

    # The internals of this function run in float32.
    # bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    # Top-k operation
    k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
    expert_gate, expert_index = mtf.top_k(
        raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
    expert_mask = mtf.one_hot(expert_index, experts_dim)

    # LOAD BALANCING LOSS
    outer_batch_dim = inputs.shape[0]
    batch_dim = inputs.shape[1]
    group_size_dim = inputs.shape[-2]
    # density_1: mean of the one-hot selections over the group;
    # density_1_proxy: mean of the soft gate values over the group.
    density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
    density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
    if importance is not None:
        expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
        expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
        density_1_proxy *= mtf.cast(
            mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    loss = (
        mtf.reduce_mean(density_1_proxy * density_1) *
        float(experts_dim.size * experts_dim.size))
    if num_microbatches and num_microbatches > 1:
        tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
            num_microbatches))
        loss /= num_microbatches

    # Logging
    if train:
        entropy = mtf.reduce_sum(
            -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
        batch_entropy = mtf.reduce_mean(entropy)
        mtf.scalar_summary(name + "/entropy", batch_entropy)

        mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
        total_routed = mtf.reduce_sum(mask_count_experts)
        expert_fraction = mtf.to_float(mask_count_experts / total_routed)
        split_fractions = mtf.split(
            expert_fraction,
            split_dim=experts_dim,
            num_or_size_splits=experts_dim.size)
        # One scalar summary per expert.
        for fraction in split_fractions:
            mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                               mtf.reduce_mean(fraction))
        mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

    # COMPUTE ASSIGNMENT TO EXPERT
    # Iteratively route tokens (no-token-left-behind). The idea is to route as
    # many tokens as possible to top-i before then trying top-(i+1).
    top_k_masks = mtf.split(
        expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
    top_k_gates = mtf.split(
        expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
    top_k_indices = mtf.split(
        expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

    # Tensors cumulative values over the iterative process.
    combine_tensor = mtf.constant(
        inputs.mesh,
        value=0,
        shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
    cum_tokens = mtf.constant(
        inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
    tokens_left_to_route = mtf.constant(
        inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])

    expert_capacity_float = float(expert_capacity_dim.size)
    for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
                                                     top_k_indices):
        top_i_mask = mtf.reshape(
            top_i_mask,
            new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
        # Operate only on the unrouted tokens.
        top_i_mask *= tokens_left_to_route

        # Record cumulative number of tokens to each expert across iterations.
        cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
            top_i_mask, group_size_dim)

        # Despite its name, this is 1.0 where the running count is still within
        # capacity (less_equal) and 0.0 where it exceeds it.
        expert_overflow = mtf.to_float(
            mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
        output_i_tokens = top_i_mask * expert_overflow

        # Update the cumulative tokens routed to each expert.
        cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
        tokens_left_to_route -= (
            mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))

        # Combine-tensor for this iteration
        output_i_tokens_flat = mtf.reduce_sum(
            output_i_tokens, reduced_dim=experts_dim)
        # The cumulative count is 1-based; slots within an expert are 0-based.
        position_in_expert = cumulative_tokens_in_expert - 1
        top_i_combine_tensor = (
            top_i_gate * output_i_tokens_flat *
            mtf.one_hot(top_i_index, experts_dim) *
            mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
        combine_tensor += top_i_combine_tensor

    # Match the inputs dtype.
    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)
    # Round-trip through bool to turn every nonzero weight into 1.
    dispatch_tensor = mtf.cast(
        mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
def _rand_1_gating(
        inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
        hparams, train, variable_dtype, importance=None, name="rand_1_gating",
        num_microbatches=None):
    """Compute a random top-1 gating.

    One expert is chosen per token, either by arg-max over the softmax gates
    (policies "argmax", "input_dropout", "input_jitter") or by sampling from
    the gate logits (policy "sample"). Tokens beyond an expert's capacity are
    dropped.

    Returns:
        A tuple ``(dispatch_tensor, combine_tensor, loss)``, all cast to
        ``inputs.dtype``; ``loss`` is the load-balancing auxiliary loss.

    Raises:
        ValueError: if the selected policy is not one of the four above.
    """
    # SELECT EXPERT
    policy = (hparams.moe_rand_1_policy_train
              if train else hparams.moe_rand_1_policy_eval)

    # The internals of this function run in float32.
    # bfloat16 seems to reduce quality.
    gate_inputs = mtf.to_float(inputs)

    # Input perturbations (training only).
    if train:
        if policy == "input_dropout":
            gate_inputs = mtf.dropout(gate_inputs,
                                      1.0 - hparams.moe_rand_1_dropout)
        elif policy == "input_jitter":
            gate_inputs = mtf.layers.multiplicative_jitter(
                gate_inputs, hparams.moe_rand_1_jitter)

    gate_logits = mtf.layers.dense(
        gate_inputs,
        experts_dim,
        use_bias=False,
        expert_dims=outer_expert_dims,
        variable_dtype=variable_dtype,
        name=name)
    raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)
    gates_dtype = raw_gates.dtype

    if policy in ("argmax", "input_dropout", "input_jitter"):
        expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
    elif policy == "sample":
        expert_index = mtf.sample_with_temperature(
            gate_logits, experts_dim, temperature=hparams.moe_rand_1_temperature)
        expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
    else:
        raise ValueError("Unknown rand_1 policy %s" % policy)

    expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=gates_dtype)

    # LOAD BALANCING LOSS
    # TODO(liamfedus): Check entropy loss.
    group_size_dim = inputs.shape[-2]
    density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
    density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
    if importance is not None:

        def _is_important():
            # Built afresh for each use, exactly as the three inline casts were.
            return mtf.cast(mtf.equal(importance, 1.0), dtype=gates_dtype)

        expert_mask *= _is_important()
        expert_gate *= _is_important()
        density_1_proxy *= _is_important()

    num_experts = experts_dim.size
    loss = (
        mtf.reduce_mean(density_1_proxy * density_1) *
        float(num_experts * num_experts))
    if num_microbatches and num_microbatches > 1:
        tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
            num_microbatches))
        loss /= num_microbatches

    # Logging
    if train:
        gate_entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                                      reduced_dim=experts_dim)
        mtf.scalar_summary(name + "/entropy", mtf.reduce_mean(gate_entropy))

        mask_count_experts = mtf.reduce_sum(expert_mask,
                                            output_shape=[experts_dim])
        total_routed = mtf.reduce_sum(mask_count_experts)
        expert_fraction = mtf.to_float(mask_count_experts / total_routed)
        per_expert_fractions = mtf.split(
            expert_fraction,
            split_dim=experts_dim,
            num_or_size_splits=num_experts)
        for fraction in per_expert_fractions:
            summary_name = "experts/" + fraction.name.replace(":", "/")
            mtf.scalar_summary(summary_name, mtf.reduce_mean(fraction))
        mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

    # COMPUTE ASSIGNMENT TO EXPERT
    # Experts have a limited capacity, ensure we do not exceed it. Construct
    # the batch indices, to each expert, with position_in_expert
    position_in_expert = mtf.cumsum(
        expert_mask, group_size_dim, exclusive=True) * expert_mask
    position_in_expert = mtf.cast(position_in_expert, dtype=gates_dtype)

    # Keep only tokens that fit within expert_capacity.
    expert_capacity_float = float(expert_capacity_dim.size)
    expert_mask *= mtf.cast(
        mtf.less(position_in_expert, expert_capacity_float), dtype=gates_dtype)
    expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)

    # Mask out the experts that have overflowed expert capacity. Sparsify the
    # expert_gate.
    expert_gate *= expert_mask_flat

    combine_tensor = (
        expert_gate * expert_mask_flat *
        mtf.one_hot(expert_index, experts_dim, dtype=gates_dtype) *
        mtf.one_hot(
            mtf.to_int32(position_in_expert),
            expert_capacity_dim,
            dtype=gates_dtype))

    # Match the inputs dtype.
    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)
    is_routed = mtf.cast(combine_tensor, tf.bool)
    dispatch_tensor = mtf.cast(is_routed, combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
def recon_prototype(mesh, data, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32):
    """Prototype of a function computing the LPT displacement and recon loss.

    Builds, on ``mesh``, a variable holding the linear field in blocked
    layout, evolves it (LPT only, or LPT + N-body when ``FLAGS.nbody`` is
    set), paints the final field, and forms a loss made of a smoothed
    chi-square against ``data`` plus a power-spectrum prior on the variable.

    Args:
        mesh: the ``mtf.Mesh`` to build on.
        data: observed field; imported with shape ``[batch, nc, nc, nc]``.
        nc: number of cells per side.
        bs: box size.
        batch_size: batch size.
        a0: first scale factor of the N-body stages.
        a: last scale factor of the N-body stages.
        nsteps: number of stages between ``a0`` and ``a`` (inclusive).
        dtype: ``tf.float32`` or ``tf.float64``.

    Returns:
        A tuple ``(initc, final_field, loss, var_grads, update_op, linearop,
        input_field, lr, R0)`` mixing mesh-tensorflow tensors/ops with the
        TensorFlow placeholders ``input_field``, ``lr`` and ``R0``.

    Raises:
        ValueError: if ``dtype`` is neither ``tf.float32`` nor ``tf.float64``.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and failed later with an unbound-local
        # error on npdtype.
        raise ValueError(
            "dtype must be tf.float32 or tf.float64, got %s" % (dtype,))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not exceed half of the smallest block edge.
    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    # Read the tabulated linear power spectrum once (it used to be parsed
    # twice, once per column).
    power_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    klin = power_table[0]
    plin = power_table[1]
    # Unused below; kept because constructing the spline is the only place the
    # table is exercised before it enters the graph.
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype(npdtype),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype(npdtype),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype(npdtype),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False, dtype=npdtype)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)

    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size
    ],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[0].squeeze().astype(npdtype),
                                 shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[1].squeeze().astype(npdtype),
                                 shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[2].squeeze().astype(npdtype),
                                 shape=[padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    # kvec for prior blocks
    prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x)
    prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y)
    prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z)

    kvec_pr = flowpm.kernels.fftk(
        [nc // n_block_x, nc // n_block_y, nc // n_block_z],
        symmetric=False,
        dtype=npdtype)
    kx_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[0].squeeze().astype(npdtype),
                                 shape=[prior_sx_dim])
    ky_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[1].squeeze().astype(npdtype),
                                 shape=[prior_sy_dim])
    kz_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[2].squeeze().astype(npdtype),
                                 shape=[prior_sz_dim])
    kv_pr = [ky_pr, kz_pr, kx_pr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Compute initial conditions, distributed.
    fieldvar = mtf.get_variable(mesh, 'linear', hr_shape)
    input_field = tf.placeholder(data.dtype, [
        batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x,
        nc // n_block_y, nc // n_block_z
    ])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    field = fieldvar
    # Blocked -> flat layout by dropping the three block axes on each slice.
    initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field],
                          output_dtype=tf.float32,
                          output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                          name='my_dumb_reshape',
                          splittable_dims=part_shape[:-1] + hr_shape[:4])

    # Pad every block with a halo, then exchange halos between neighbours.
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack using a custom reshape because mesh is pretty dumb
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=initc.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init(low,
                               high,
                               0.1,
                               kv_lr,
                               kv_hr,
                               halo_size,
                               hr_shape,
                               lr_shape,
                               part_shape[1:],
                               downsampling_factor=downsampling_factor,
                               antialias=True)

        final_state = mtfpm.nbody(state,
                                  stages,
                                  lr_shape,
                                  hr_shape,
                                  kv_lr,
                                  kv_hr,
                                  halo_size,
                                  downsampling_factor=downsampling_factor)
    else:
        final_state = mtfpm.lpt_init(low,
                                     high,
                                     stages[-1],
                                     kv_lr,
                                     kv_hr,
                                     halo_size,
                                     hr_shape,
                                     lr_shape,
                                     part_shape[1:],
                                     downsampling_factor=downsampling_factor,
                                     antialias=True)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    def _squared_wavenumber(kx, ky, kz):
        """Return the squared wavenumber grid from three 1-D k vectors."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        return (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv_pr]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Divide mode amplitudes by the square root of the interpolated pk."""
        kk = tf.sqrt(_squared_wavenumber(kx, ky, kz))
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.placeholder(tf.float32, shape=())

    def _cwise_smooth(kfield, kx, ky, kz):
        """Apply a Gaussian low-pass whose scale is set by R0."""
        kk = _squared_wavenumber(kx, ky, kz)
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Multiply by one minus a Gaussian of scale R0 * bs / nc + 1 / nyq."""
        kk = _squared_wavenumber(kx, ky, kz)
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype)
    var_grads = [
        mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype)
    ]

    # Plain gradient step with a fed-in learning rate.
    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, final_field, loss, var_grads, update_op, linearop, input_field, lr, R0
def _sample(self, features, mesh):
    """Build the decoding graph and return the decoded ids.

    For ``transformer_type == "encdec"`` the inputs are padded to
    ``hparams.batch_size`` x ``hparams.max_length``, encoded once, and the
    encoder-decoder attention keys/values are precomputed per "enc_att"
    layer. For ``"decoder"`` the optional ``features["inputs"]`` (or, failing
    that, ``features["targets"]``) is padded and used as a forced prefix.

    Args:
        features: dict of TF tensors; reads "inputs" and/or "targets".
        mesh: the ``mtf.Mesh`` to build on.

    Returns:
        The result of ``mtf.beam_search.greedy_decode`` when
        ``hparams.beam_size == 1``; otherwise the index-0 slice along the beam
        dimension of the result of ``mtf.beam_search.beam_search``.

    Raises:
        ValueError: if ``hparams.transformer_type`` is neither "encdec" nor
            "decoder".
    """
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
        inputs = features["inputs"]
        # Drop extra axes until the tensor is rank 2.
        while len(inputs.shape.as_list()) > 2:
            inputs = tf.squeeze(inputs, axis=2)
        actual_batch_size = tf.shape(inputs)[0]
        actual_length = tf.shape(inputs)[1]
        # Pad up to the static batch size and length from hparams.
        inputs = tf.pad(inputs, [[0, hparams.batch_size - actual_batch_size],
                                 [0, hparams.max_length - actual_length]])
        inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.reshape(positional_embedding_var,
                         mtf.Shape([self.length_dim, self.model_dim])))
        encoder_attention_mask = (mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
        with tf.variable_scope("encoder"):
            x = self._layer_stack(
                x,
                hparams.encoder_layers,
                self_attention_mask=encoder_attention_mask)
        encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                              self.memory_length_dim.name)
        # One entry per decoder layer: precomputed (q_var, o_var, k, v) for
        # "enc_att" layers, None for every other layer type.
        encdec_tensors = []
        for layer_num, layer_type in enumerate(hparams.decoder_layers):
            if layer_type == "enc_att":
                with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
                    q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                        mesh, self.heads_dim, self.model_dim, self.kv_dim,
                        self.master_dtype, self.slice_dtype, self.activation_dtype)
                    k = mtf.einsum([encoder_output, k_var],
                                   mtf.Shape(self.batch_dims + [
                                       self.heads_dim, self.memory_length_dim,
                                       self.kv_dim
                                   ]))
                    v = mtf.einsum([encoder_output, v_var],
                                   mtf.Shape(self.batch_dims + [
                                       self.heads_dim, self.memory_length_dim,
                                       self.kv_dim
                                   ]))
                encdec_tensors.append((q_var, o_var, k, v))
            else:
                encdec_tensors.append(None)
        partial_targets = None
    elif hparams.transformer_type == "decoder":
        encdec_tensors = None
        encoder_output = None
        encoder_attention_mask = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs", None)
        if partial_targets is None:
            partial_targets = features.get("targets", None)
        if partial_targets is not None:
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int32(partial_targets)
            partial_targets_batch = tf.shape(partial_targets)[0]
            partial_targets_length = tf.shape(partial_targets)[1]
            partial_targets = tf.pad(
                partial_targets,
                [[0, hparams.batch_size - partial_targets_batch],
                 [0, hparams.max_length - partial_targets_length]])
            partial_targets = self._import_to_batch_by_length(
                partial_targets, "partial_targets", mesh, hparams)
    else:
        # NOTE(review): the message says "model_type" but the value checked and
        # reported is hparams.transformer_type.
        raise ValueError("hparams.model_type = %s not yet supported" %
                         hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    if hparams.beam_size == 1:
        ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        kv_shape = mtf.Shape(
            self.batch_dims +
            [self.heads_dim, self.memory_length_dim, self.kv_dim])
        local_kv_shape = mtf.Shape(
            self.batch_dims +
            [self.heads_dim, local_attention_window, self.kv_dim])
    else:
        # Beam search carries an extra beam dimension on ids and states.
        beam_dim = mtf.Dimension("beam", hparams.beam_size)
        ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
        kv_shape = mtf.Shape(self.batch_dims + [
            beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim
        ])
        local_kv_shape = mtf.Shape(self.batch_dims + [
            beam_dim, self.heads_dim, local_attention_window, self.kv_dim
        ])
    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    # Two zero-initialized state tensors per attention layer; the `* 2` repeats
    # the same zeros tensor object twice in the list.
    initial_states = []
    for layer in hparams.decoder_layers:
        if layer == "att":
            initial_states.extend(
                [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
        elif layer == "local_att":
            initial_states.extend([
                mtf.zeros(
                    mesh, local_kv_shape, dtype=self.activation_dtype)
            ] * 2)

    def logits_fn(step_num, ids, states):
        """Produce logits for this step, and new states."""
        # The id at position step_num - 1 is the input for this step.
        ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
        x = (mtf.gather(targets_embedding_var, ids_this_step,
                        self.targets_vocab_dim) +
             mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
        with tf.variable_scope("decoder"):
            x, new_states = self._layer_stack(
                x,
                hparams.decoder_layers,
                encdec_attention_mask=encoder_attention_mask,
                step_num=step_num,
                encdec_tensors=encdec_tensors,
                states=states)
        logits = mtf.matmul(x, softmax_var)
        return logits, new_states

    if hparams.beam_size == 1:
        # A temperature of 0.0 is passed for the "argmax" sampling method.
        temperature = (0.0 if hparams.sampling_method == "argmax" else
                       hparams.sampling_temp)
        return mtf.beam_search.greedy_decode(logits_fn,
                                             initial_ids,
                                             temperature=temperature,
                                             initial_states=initial_states,
                                             forced_ids=partial_targets,
                                             use_tpu=hparams.use_tpu)
    else:
        if hparams.transformer_type == "encdec":
            # Count of nonzero input ids per sequence; the longest one sets the
            # decode length.
            input_length = mtf.reduce_sum(mtf.to_float(
                mtf.cast(inputs, tf.bool)),
                                          reduced_dim=self.length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * hparams.decode_length_multiplier +
                hparams.decode_length_constant, tf.int32)
        else:
            decode_length = None
        beams, unused_scores = mtf.beam_search.beam_search(
            logits_fn,
            initial_ids,
            hparams.alpha,
            states=initial_states,
            decode_length=decode_length,
            use_tpu=hparams.use_tpu,
            dtype=self.activation_dtype)
        # Keep only the beam at index 0.
        return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
    """Assemble the spatially partitioned UNet graph, its loss and metrics.

    Args:
      mesh: a MeshTensorflow.mesh object.
      mesh_impl: a mesh implementation, such as SimdMeshImpl and
        PlacementMeshImpl.
      dataset_str: either 'train' or 'eval'. Selects the batch-size flag and
        whether the convolution blocks are built in training mode.
      images: a laid out Tensor with shape [batch, x, y, num_channels] or
        [batch, x, y, z, num_channels].
      labels: a laid out Tensor with shape [batch, x, y, num_classes] or
        [batch, x, y, z, num_classes].

    Returns:
      A 4-tuple (preds, loss, eval_metrics, all_bn_update_ops):
        preds: list with the downsampled liver and lesion probabilities,
          optionally the downsampled lesion ground truth, followed by the
          per-case lesion intersection and area sum.
        loss: Dice loss and cross-entropy loss mixed by
          FLAGS.dice_loss_weight.
        eval_metrics: dict of metric tensors keyed by metric name.
        all_bn_update_ops: the update ops returned by every convolution
          block, in construction order.
    """
    is_training = dataset_str == 'train'
    if is_training:
        batch_size = FLAGS.batch_size_train
    else:
        assert dataset_str == 'eval'
        batch_size = FLAGS.batch_size_eval
    batch_dim = mtf.Dimension('batch', batch_size)

    use_2d = FLAGS.sampled_2d_slices

    image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
    image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
    image_sx_dim = mtf.Dimension(
        'image_sx_block', FLAGS.ct_resolution // FLAGS.image_nx_block)
    image_sy_dim = mtf.Dimension(
        'image_sy_block', FLAGS.ct_resolution // FLAGS.image_ny_block)
    image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
    image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
    label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
    # Size-1 channel dimension left behind by slicing out a single class.
    single_label_dim = mtf.Dimension('label_c', 1)
    mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

    mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
    variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

    # Import the input features and the ground truth labels.
    x = mtf.cast(
        mtf.import_laid_out_tensor(
            mesh, mesh_impl.LaidOutTensor(images), mtf_images_shape),
        mtf_dtype)
    t = mtf.cast(
        mtf.import_laid_out_tensor(
            mesh, mesh_impl.LaidOutTensor(labels), mtf_labels_shape),
        mtf_dtype)

    # Transpose the blocks; volumes carry an extra z dimension.
    block_dims = [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim
    ]
    if not use_2d:
        block_dims.append(image_sz_dim)
    x = mtf.transpose(x, block_dims + [image_c_dim])
    t = mtf.transpose(t, block_dims + [label_c_dim])

    # Network.
    skip_levels = []
    all_bn_update_ops = []
    deepest = FLAGS.network_depth - 1

    # Contracting path: convolution blocks followed by max pooling.
    for depth in range(FLAGS.network_depth):
        for n_conv in range(FLAGS.n_conv_per_block):
            # The very first layer is built without dropout.
            is_first_layer = depth == 0 and n_conv == 0
            dropout_keep_p = 1.0 if is_first_layer else FLAGS.dropout_keep_p
            x, bn_update_ops = conv_with_spatial_partition(
                x, use_2d, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth),
                dropout_keep_p,
                FLAGS.with_batch_norm,
                is_training,
                'conv_{}_{}'.format(depth, n_conv),
                variable_dtype,
                'conv_down_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)
        skip_levels.append(x)

        if depth < deepest:
            if use_2d:
                x = mtf.layers.max_pool2d(x, ksize=(2, 2))
            else:
                x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

    # Expanding path: up-convolution, skip concatenation, convolution blocks.
    for depth in reversed(range(deepest)):
        skip_dim_name = 'conv_{}_{}'.format(
            depth, FLAGS.n_conv_per_block - 1)
        x = deconv_with_spatial_partition(
            x, use_2d, image_nx_dim, image_ny_dim,
            FLAGS.n_base_filters * (2**depth),
            FLAGS.dropout_keep_p,
            skip_dim_name,
            variable_dtype,
            'deconv_{}_0'.format(depth))
        x = mtf.concat(
            [x, skip_levels[depth]], concat_dim_name=skip_dim_name)

        for n_conv in range(FLAGS.n_conv_per_block):
            x, bn_update_ops = conv_with_spatial_partition(
                x, use_2d, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth),
                FLAGS.dropout_keep_p,
                FLAGS.with_batch_norm,
                is_training,
                'conv_{}_{}'.format(depth, n_conv),
                variable_dtype,
                'conv_up_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)

    # Final 1x1 projection onto the label classes (no dropout here).
    final_conv_name = 'final_conv_{}'.format(FLAGS.label_c)
    if use_2d:
        y = mtf.layers.conv2d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            h_blocks_dim=image_nx_dim,
            w_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name=final_conv_name,
        )
    else:
        y = mtf.layers.conv3d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1, 1),
            strides=(1, 1, 1),
            padding='SAME',
            d_blocks_dim=image_nx_dim,
            h_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name=final_conv_name,
        )

    def scalar(v, dtype):
        # Build constants through mtf so that none live on the CPU side.
        return mtf.constant(mesh, v, shape=[], dtype=dtype)

    # Class index 1 is treated as liver and class index 2 as lesion.
    argmax_t = mtf.argmax(t, label_c_dim)
    liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
    lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)

    argmax_y = mtf.argmax(y, label_c_dim)
    lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

    # Summaries: class ratios and overall accuracy.
    lesion_pred_ratio = mtf.reduce_mean(lesion_y)
    lesion_label_ratio = mtf.reduce_mean(lesion_t)
    accuracy = mtf.reduce_mean(
        mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

    # Cross-entropy loss, with liver and lesion pixels up-weighted.
    pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
        y, t, label_c_dim)
    pixel_weight = (
        scalar(1, mtf_dtype)
        + liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype)
        + lesion_t * scalar(
            FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight, mtf_dtype))
    loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

    # Dice loss on the lesion probability, per case and over the batch.
    per_case_shape = mtf.Shape([batch_dim])
    y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
    lesion_prob = mtf.reduce_sum(
        mtf.slice(y_prob, 2, 1, 'label_c'), reduced_dim=single_label_dim)
    prob_intersect = mtf.reduce_sum(
        lesion_prob * lesion_t, output_shape=per_case_shape)
    prob_area_sum = mtf.reduce_sum(
        lesion_prob + lesion_t, output_shape=per_case_shape)

    loss_dice_per_case = mtf.reduce_mean(
        scalar(-2, mtf_dtype) * prob_intersect /
        (prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
    loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(
        prob_intersect) / (mtf.reduce_sum(prob_area_sum) +
                           scalar(FLAGS.dice_epsilon, mtf_dtype))
    loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(
        0.5, mtf_dtype)

    loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
        1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

    # Hard-prediction Dice, reported as a metric only.
    intersect = mtf.reduce_sum(
        lesion_y * lesion_t, output_shape=per_case_shape)
    area_sum = mtf.reduce_sum(
        lesion_y + lesion_t, output_shape=per_case_shape)
    dice_per_case = mtf.reduce_mean(
        scalar(2, mtf_dtype) * intersect /
        (area_sum + scalar(0.000001, mtf_dtype)))
    dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
        mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

    eval_metrics = {
        'lesion_pred_ratio': lesion_pred_ratio,
        'lesion_label_ratio': lesion_label_ratio,
        'accuracy_of_all_classes': accuracy,
        'lesion_dice_per_case': dice_per_case,
        'lesion_dice_global': dice_global,
        'loss_xen': loss_xen,
        'loss_dice': loss_dice,
        'loss_dice_per_case': loss_dice_per_case,
        'loss_dice_global': loss_dice_global,
    }

    # Downsample the probabilities (and optionally the ground truth).
    if use_2d:
        avg_pool = mtf.layers.avg_pool2d
        pool_ksize = (FLAGS.pred_downsample, ) * 2
    else:
        avg_pool = mtf.layers.avg_pool3d
        pool_ksize = (FLAGS.pred_downsample, ) * 3
    y_prob_downsampled = avg_pool(y_prob, ksize=pool_ksize)
    if FLAGS.output_ground_truth:
        lesion_gt_downsampled = avg_pool(
            mtf.slice(t, 2, 1, 'label_c'), ksize=pool_ksize)

    liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
    lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
    preds = [
        mtf.reduce_sum(liver_prob_downsampled, reduced_dim=single_label_dim),
        mtf.reduce_sum(lesion_prob_downsampled, reduced_dim=single_label_dim),
    ]
    if FLAGS.output_ground_truth:
        preds.append(
            mtf.reduce_sum(
                lesion_gt_downsampled, reduced_dim=single_label_dim))
    preds.extend([intersect, area_sum])

    return preds, loss, eval_metrics, all_bn_update_ops
def decode(self,
           inputs,
           variable_dtype=mtf.VariableDType(tf.float32),
           beam_size=1,
           alpha=0.6,
           temperature=0.0,
           sampling_keep_top_k=-1,
           decode_length_multiplier=1.5,
           decode_length_constant=10,
           max_decode_length=None):
    """Run the encoder, then decode by sampling or by beam search.

    Args:
      inputs: a Tensor whose last dimension is the length dimension; all
        leading dimensions are treated as batch dimensions.
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1. 1 selects autoregressive sampling, any
        other value selects beam search.
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted
        distribution.
      sampling_keep_top_k: a value between 1 and vocab_size used to sample
        from only the k most likely logits. Set to -1 to sample from all
        logits. Must be -1 if beam_size > 1.
      decode_length_multiplier: a float
      decode_length_constant: a float
      max_decode_length: an optional integer. When given, it is the size of
        the "length" dimension of the decoded ids; otherwise the input length
        dimension is reused.

    Returns:
      Whatever self.decoder.sample_autoregressive (beam_size == 1) or
      self.decoder.beam_search (otherwise) returns.

    Raises:
      ValueError: if beam search is requested together with a nonzero
        temperature or with top-k sampling.
    """
    mesh = inputs.mesh
    encoder_layer_outputs = []
    shared_params = self._shared_params(mesh, variable_dtype)
    input_sequence_id = mtf.minimum(inputs, 1)
    encoder_output, unused_encoder_loss = self.encoder.call_simple(
        inputs=inputs,
        targets=None,
        compute_loss=False,
        mode=tf.estimator.ModeKeys.PREDICT,
        variable_dtype=variable_dtype,
        sequence_id=input_sequence_id,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs)
    encoder_output = mtf.layers.rename_length_to_memory_length(
        encoder_output)

    # The sequence_id is updated inside the layer_stack due to pooling, so
    # the updated sequence_id stored in the context is the one to use.
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        self.encoder.layer_stack.context.sequence_id)

    batch_dims = inputs.shape[:-1]
    length_dim = inputs.shape[-1]
    decode_length_dim = (
        length_dim if max_decode_length is None
        else mtf.Dimension("length", max_decode_length))

    # Keyword arguments shared by both decoding strategies.
    decoder_kwargs = dict(
        variable_dtype=variable_dtype,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs)

    if beam_size == 1:
        partial_sequences = mtf.zeros(
            mesh,
            mtf.Shape(batch_dims + [decode_length_dim]),
            dtype=tf.int32)
        return self.decoder.sample_autoregressive(
            partial_sequences,
            temperature=temperature,
            sampling_keep_top_k=sampling_keep_top_k,
            encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
            has_partial_sequences=False,
            **decoder_kwargs)

    # Beam search from here on.
    if temperature != 0:
        raise ValueError(
            "don't know how to beam search with nonzero temperature")
    if sampling_keep_top_k != -1:
        raise ValueError(
            "don't know how to beam search with top-k value other than -1.")

    beam_dim = mtf.Dimension("beam", beam_size)
    partial_sequences = mtf.zeros(
        mesh,
        mtf.Shape(batch_dims + [beam_dim, decode_length_dim]),
        dtype=tf.int32)

    # Decode budget: longest non-padding input length, scaled and offset.
    nonpadding = mtf.to_float(mtf.cast(inputs, tf.bool))
    input_length = mtf.reduce_sum(nonpadding, reduced_dim=length_dim)
    max_input_length = mtf.reduce_max(input_length)
    decode_length = mtf.cast(
        max_input_length * decode_length_multiplier + decode_length_constant,
        tf.int32)

    return self.decoder.beam_search(
        partial_sequences,
        decode_length,
        encoder_inputs=inputs,
        alpha=alpha,
        **decoder_kwargs)
def benchmark_model(mesh):
    """Build the benchmark graph: linear field, LPT init, CIC paint, sum.

    Draws a linear field on the mesh from the tabulated power spectrum,
    initializes particles with `mtfpm.lpt_init_single`, paints them onto a
    block-decomposed grid with halo exchange and returns the sum of the
    painted field.

    Args:
      mesh: the mtf mesh on which all tensors are created.

    Returns:
      The result of `mtf.reduce_sum` over the painted field.
    """
    # Setup parameters
    bs = FLAGS.box_size
    nc = FLAGS.cube_size
    batch_size = FLAGS.batch_size
    a0 = FLAGS.a0
    a = 1.0
    nsteps = FLAGS.pm_steps

    # Read the tabulated power spectrum once; the first two columns are used
    # as k and P(k).
    klin, plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[:2]
    ipklin = iuspline(klin, plin)

    # Integration steps; only referenced by the disabled nbody call below.
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # NOTE(review): initial_conditions, lap and grad_* are not consumed
    # anywhere below. They are kept because flowpm.linear_field and tf.cast
    # may add ops to the TF graph -- confirm before removing them.
    # Generate a batch of 3D initial conditions
    initial_conditions = flowpm.linear_field(
        nc,  # size of the cube
        bs,  # Physical size of the cube
        ipklin,  # Initial power spectrum
        batch_size=batch_size)

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    from flowpm.kernels import laplace_kernel, gradient_kernel
    lap = tf.cast(laplace_kernel(kvec), tf.complex64)
    grad_x = gradient_kernel(kvec, 0)
    grad_y = gradient_kernel(kvec, 1)
    grad_z = gradient_kernel(kvec, 2)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = 8
    n_block_y = 4
    n_block_z = 1
    halo_size = 4

    # Parameters of the large scales decomposition; only referenced by the
    # disabled lpt_init / nbody calls below.
    downsampling_factor = 2

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # Import the Fourier wave vectors computed above onto the mesh.
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # Tensor shapes: full grid, low resolution, block decomposed, particles.
    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)
    state = mtfpm.lpt_init_single(
        initc,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    #state = mtfpm.lpt_init(low, high, 0.1, kv_lr, kv_hr, halo_size, hr_shape, lr_shape,
    #                       part_shape[1:], downsampling_factor=downsampling_factor, antialias=True,)

    # Here we can run our nbody
    final_state = state  #mtfpm.nbody(state, stages, lr_shape, hr_shape, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    #final_field = mtf.reshape(final_field, [batch_dim, fx_dim, fy_dim, fz_dim])
    # Hack: use a custom slicewise "reshape" instead of mtf.reshape.
    final_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [final_field],
                                output_dtype=tf.float32,
                                output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                                name='my_dumb_reshape',
                                splittable_dims=part_shape[:-1] + hr_shape[:4])

    return mtf.reduce_sum(final_field)
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
    """Implements GBSWT from Charformer.

    For block widths 2, 3 and 4 (plus 1 when `consider_chars_as_blocks`),
    mean-pools `x` along `length_dim`, scores every block with a dense
    layer, expands blocks and scores back to the input length, and returns
    the score-weighted sum of the block representations, optionally
    mean-pooled by `downsample`.

    Args:
      x: a Tensor containing length_dim; its last dimension is used as the
        model dimension.
      length_dim: a Dimension
      max_subword_length: integer. Currently unused.
      downsample: integer. When truthy, the output is padded to a multiple
        of this value and pooled by it along the length dimension.
      use_offsets: boolean. When set, each block width is also evaluated at
        every shift in [1, width).
      consider_chars_as_blocks: boolean. Adds width-1 blocks.
      use_block_pos_embedding: boolean. Adds a sinusoidal position embedding
        inside each block before pooling; requires `context`.
      share_block_kernel: boolean. Uses one "block_kernel" variable as the
        scoring kernel; requires `context`.
      memory_embeddings: integer. Currently unused.
      context: Context. Only read when `use_block_pos_embedding` or
        `share_block_kernel` is set.
      block_mixing_mode: Str for block mixing. "score_attention" mixes the
        block scores with a softmax attention over themselves.
      activation: Str for block ranking. "softmax" and "tanh" are applied to
        the scores; "sigtanh" only widens the score dimension to 2.
      downsample_function: Str, only "mean" is implemented.

    Returns:
      a Tensor with the same shape as x, except that the length dimension is
      reduced when `downsample` is set.

    Raises:
      ValueError: if `downsample` is set and `downsample_function` is not
        "mean".
    """
    # don't use this for now.
    del max_subword_length
    del memory_embeddings
    all_blocks = []
    all_scores = []
    tf.logging.info("GSW block layer")

    def _tile(x, n, tile_dim):
        # Simple tile function in MTF.
        return mtf.concat([x] * n, tile_dim.name)

    def _repeat(x, n, repeat_dim):
        # repeat function in MTF: append a size-1 "tmp" dimension, tile it n
        # times, then fold it into repeat_dim with a reshape.
        tmp_dim = mtf.Dimension("tmp", 1)
        expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
        x = mtf.reshape(x, expand_shape)
        x = _tile(x, n, tmp_dim)
        output_shape = []
        for dim in x.shape.dims:
            if dim.name == "tmp":
                continue
            if dim.name == repeat_dim.name:
                dim = mtf.Dimension(dim.name, dim.size * n)
            output_shape.append(dim)
        output_shape = mtf.Shape(output_shape)
        x = mtf.reshape(x, output_shape)
        return x

    def _combined_dim(dims):
        # One dimension named after dims[0] whose size is the product of all.
        return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

    # compute all subword blocks
    # TODO(yitay): handle offsets to get all blocks
    if activation == "sigtanh":
        # one score for sigmoid
        tmp_dim = mtf.Dimension("block_score", 2)
    else:
        tmp_dim = mtf.Dimension("block_score", 1)

    model_dim = x.shape[-1]
    subword_blocks_width = [2, 3, 4]

    if consider_chars_as_blocks:
        subword_blocks_width += [1]

    if share_block_kernel:
        block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
        block_kernel = mtf.get_variable(x.mesh,
                                        "block_kernel",
                                        block_kernel_shape,
                                        initializer=None,
                                        dtype=context.variable_dtype)
    else:
        block_kernel = None

    for subword_len in subword_blocks_width:
        if use_block_pos_embedding:
            # this is turned off by default. It is meant to support cases
            # like parameterized pooling or other features.
            block_len_dim = mtf.Dimension(length_dim.name, subword_len)
            # TODO(vqtran): Consider other positional embeddings.
            block_pos_emb = sinusoid_positional_embedding_weights(
                context.mesh, block_len_dim, x.shape[-1],
                context.variable_dtype.activation_dtype)
            block_pos_emb = _repeat(
                block_pos_emb,
                math.ceil(length_dim.size / float(subword_len)),
                block_len_dim)
        if use_offsets:
            offset_space = subword_len
        else:
            offset_space = 1
        for offsets in range(offset_space):
            if offsets > 0:
                xoff = mtf.shift(x, offsets, length_dim, wrap=False)
                if use_block_pos_embedding:
                    block_pos_emb = mtf.shift(block_pos_emb,
                                              offsets,
                                              block_pos_emb.shape[-2],
                                              wrap=False)
            else:
                xoff = x
            tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
            if length_dim.size % subword_len != 0:
                tf.logging.info("Not divisible by length")
                # add extra padding tokens
                pad_amt = int(subword_len) - int(
                    length_dim.size % subword_len)
                kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
            else:
                kp = xoff

            if use_block_pos_embedding:
                kp += block_pos_emb

            # Look the pooled dimension up by length_dim's own name, so
            # inputs whose length dimension is not called "length" work too.
            bx = mtf.pool_tensor_1d(
                kp,
                pool_dim=kp.shape.get_dim_by_name(length_dim.name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(subword_len))
            block_score = mtf.layers.dense(bx, [tmp_dim],
                                           use_bias=False,
                                           name="bx",
                                           reduced_dims=[model_dim],
                                           variable_dtype=None,
                                           kernel_weights=block_kernel)

            expand_bx = _repeat(bx, subword_len, length_dim)
            expand_scores = _repeat(block_score, subword_len, length_dim)
            if offsets > 0:
                # add offset.
                expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [offsets, 0],
                                        length_dim.name)
            new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
            if new_len.size < length_dim.size:
                # Pad up to the original length; the amount must be the
                # positive shortfall (target minus current).
                pad_amt = length_dim.size - new_len.size
                expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [0, pad_amt],
                                        length_dim.name)
            elif new_len.size > length_dim.size:
                expand_bx = mtf.slice(expand_bx, 0, length_dim.size,
                                      length_dim.name)
                expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                          length_dim.name)

            # Append a size-1 dimension so candidates can be concatenated.
            new_tmp_dim = mtf.Dimension("extra_dim", 1)
            expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
            expand_scores_shape = mtf.Shape(expand_scores.shape.dims +
                                            [new_tmp_dim])
            expand_bx = mtf.reshape(expand_bx, expand_shape)
            expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
            all_blocks.append(expand_bx)
            all_scores.append(expand_scores)

    all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
    all_scores = mtf.concat(all_scores, new_tmp_dim.name)
    tf.logging.info(all_blocks)
    new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
    combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
    block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
    block_net = mtf.reshape(all_scores, block_net_shape)

    if block_mixing_mode == "score_attention":
        tf.logging.info("Using score attention")
        att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
        tf.logging.info(block_net)
        att = mtf.softmax(att, reduced_dim=att.shape[-1])
        block_net = mtf.einsum([att, block_net],
                               output_shape=block_net.shape)
        tf.logging.info(block_net)

    if activation == "softmax":
        block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
    elif activation == "tanh":
        tf.logging.info("Using tanh")
        block_net = mtf.tanh(block_net)

    # Weight every candidate block by its score and sum over candidates.
    all_blocks = block_net * all_blocks
    all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
    output = all_blocks

    if downsample:
        output_length = output.shape.get_dim_by_name(length_dim.name)
        if output_length.size % int(downsample) != 0:
            pad_amt = int(downsample) - int(
                output_length.size % int(downsample))
            output = mtf.pad(output, [0, pad_amt], output_length.name)
        if downsample_function == "mean":
            output = mtf.pool_tensor_1d(
                output,
                pool_dim=output.shape.get_dim_by_name(length_dim.name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(downsample))
        else:
            raise ValueError("Downsampling function not implemeneted.")

    return output
# Build a wide&deep graph on random inputs, then set up SGD on its loss.
# input1 = mtf.get_variable(mesh, "input1", [dim1,dim2])
# input2 = mtf.get_variable(mesh, "input2", [dim3,dim4])
# Feature ids must be valid rows of the embedding table: draw them uniformly
# from [0, vocab_dim.size). (Truncating normal samples to int32 produced
# negative and almost constant ids.)
id_hldr = mtf.import_tf_tensor(
    mesh,
    np.random.randint(
        0, vocab_dim.size,
        size=(batch_dim.size, field_dim.size)).astype(np.int32),
    shape=[batch_dim, field_dim])
wt_hldr = mtf.import_tf_tensor(
    mesh,
    np.random.randn(batch_dim.size, field_dim.size).astype(np.float32),
    shape=[batch_dim, field_dim])
# Alternating 1/0 labels sized to the batch (same values as the previous
# hard-coded vector when the batch size is 10).
label = mtf.import_tf_tensor(
    mesh,
    np.resize([1, 0], batch_dim.size).astype(np.float32),
    shape=[batch_dim])
result = widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim, outdim)
result = mtf.reduce_mean(result, reduced_dim=outdim)
# label = mtf.reshape(label,new_shape=[batch_dim, outdim])
# output = mtf.layers.softmax_cross_entropy_with_logits(result, label,vocab_dim=outdim)
# result = mtf.sigmoid(result)
# result = -(label*mtf.log(result)+(1-label)*mtf.log(1-result))
# result = mtf.reduce_sum(result)
result = mtf.layers.sigmoid_cross_entropy_with_logits(result, label)
wide_loss = mtf.reduce_sum(result)

print("========", global_step)
# Single-device placement mesh.
devices = ["gpu:0"]
mesh_shape = [("all_processors", 1)]
layout_rules = [("dim1", "all_processors")]
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mesh_shape, layout_rules, devices)
# Gradients of the loss w.r.t. every trainable variable, applied with SGD.
var_grads = mtf.gradients(
    [wide_loss], [v.outputs[0] for v in graph.trainable_variables])
optimizer = mtf.optimize.SgdOptimizer(0.01)
update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
def beam_search(self,
                inputs,
                decode_length,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_output=None,
                encoder_sequence_id=None,
                alpha=0.6,
                shared_params=None,
                encoder_layer_outputs=None):
    """Decode with beam search and return the top beam.

    Args:
      inputs: an int32 zero-Tensor with shape
        [<batch_dims>, beam_dim, length_dim].
      decode_length: an int32 mtf scalar. Maximum decode length.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      alpha: a floating point value (length bonus)
      shared_params: an optional dictionary
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer

    Returns:
      The beam at index 0 of the beams produced by
      mtf.beam_search.beam_search, i.e. a Tensor without the beam dimension.

    Raises:
      ValueError: if this model is not autoregressive.
      NotImplementedError: if `inputs` does not have exactly one batch
        dimension.
    """
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    dims = inputs.shape.dims
    batch_dims = dims[:-2]
    if len(batch_dims) != 1:
        raise NotImplementedError(
            "beam search supports exactly one batch dimension.")
    beam_dim, length_dim = dims[-2:]

    # Number of nonzero ids already present along the length dimension.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = None if encoder_sequence_id is None else 1

    def make_context(mode, **mode_kwargs):
        """Build a Context; only the mode-specific fields vary per call."""
        return Context(
            mesh=inputs.mesh,
            batch_dims=batch_dims + [beam_dim],
            length_dim=length_dim,
            model_dim=self.model_dim,
            variable_dtype=variable_dtype,
            mode=mode,
            autoregressive=self.autoregressive,
            new_states=[],
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            layout=self.layout,
            mesh_shape=self.mesh_shape,
            encoder_layer_outputs=encoder_layer_outputs,
            **mode_kwargs)

    # First pass over the shifted inputs: run only to populate the states
    # in the context; the logits themselves are discarded.
    context_first_part = make_context(
        "first_part",
        initial_position=initial_position,
        constant_states=[])
    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        self._call_internal(context_first_part, shifted_inputs)

    # There are no partial targets, so the initial states are replaced by
    # zeros to avoid computing them.
    initial_states = [
        mtf.zeros_like(state) for state in context_first_part.new_states
    ]
    constant_states = context_first_part.constant_states

    def logits_fn(step_num, ids, states):
        """logits_fn for mtf.beam_search.beam_search()."""
        context_incremental = make_context(
            "incremental",
            position=step_num,
            states=states,
            constant_states=constant_states)
        inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
        with tf.variable_scope(self.name, reuse=True):
            step_logits = self._call_internal(
                context_incremental, inputs_this_step)
        return mtf.to_float(step_logits), context_incremental.new_states

    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)
    top_beam_index = mtf.constant(inputs.mesh, 0, dtype=tf.int32)
    return mtf.gather(beams, top_beam_index, beam_dim)
def model_fn(features, labels, mode, params):
    """Estimator model function: build the mtf graph and return a spec.

    Builds a Mesh TensorFlow graph for the model selected by
    ``params["model"]`` (only "GPT" is accepted), lowers it to TensorFlow and
    returns a ``TPUEstimatorSpec`` for the requested mode.

    Args:
      features: input token tensor, or a dict holding it under the key
        "feature"; may be None.
      labels: label token tensor, or a dict holding it under the key
        "feature"; may be None.
      mode: a ``tf.estimator.ModeKeys`` value (PREDICT, TRAIN or EVAL).
      params: dict of hyperparameters. Keys read here: "mesh_shape", "layout",
        "use_tpu", "gpu_ids", "precision", "n_ctx", "causal", "n_embd",
        "n_vocab", "remove_partial_sequences", "export", "eos_id",
        "sampling_use_entmax", "tokens_per_mb_per_replica", "model",
        "auto_layout", "auto_layout_and_mesh_shape", "num_cores", "log_grads",
        "model_path", "steps_per_checkpoint" and "eval_task".
        "num_microbatches" is written into the dict returned by
        ``add_mode_to_params``.

    Returns:
      A ``tpu_estimator.TPUEstimatorSpec`` for the given mode.

    Raises:
      AssertionError: if ``mode`` is not PREDICT, TRAIN or EVAL.
      Exception: if ``params["model"]`` is not "GPT".
    """
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"]
    }

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            # isinstance (rather than an exact type comparison) so that dict
            # subclasses are unwrapped as well.
            if isinstance(features_dict[key], dict):
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])
        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params[
        "num_microbatches"] = num_microbatches  # Add num microbatches to params

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype)
                return {
                    "logits": logits,
                    "loss": loss,
                    "loss_batch": loss_batch
                }
            else:
                raise Exception(
                    f"'{params['model']}' is not a valid model - please select from [GPT]"
                )

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            raise Exception(
                f"'{params['model']}' is not a valid model - please select from [GPT]"
            )

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(
            global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(
                    tf.cast(correct_answers, tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff added to it
                # so maybe this should be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
def transformer_moe_layer_v2(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None):
    """2-level mixture of experts.

    Adapted from the paper https://arxiv.org/abs/1701.06538

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters dictionary in order not to complicate the interface in
    mtf_transformer.py .  Once this code moves out of "research", we should
    pass the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_num_experts: number of experts (a sequence of length 2: the
        first-level and second-level expert counts)
      hparams.moe_hidden_size: size of hidden layer in each expert
      hparams.moe_group_size: size of each "group" for gating purposes
      hparams.moe_capacity_factor_train: a float
      hparams.moe_capacity_factor_eval: a float
      hparams.moe_capacity_factor_second_level: a float
      hparams.moe_gating: a string
      hparams.moe_loss_coef: multiplier applied to the summed gating losses
      + all hyperparameters used by _top_2_gating()

    One set of params for experts in first level and different of hparams
    per expert in the second level.
    The number of parameters in the gating network is:
      (input_dim.size * (hparams.num_experts) +
        (moe_hidden_size * hparams.num_experts) * hparams.num_experts

    The number of parameters in the experts themselves is:
      (hparams.num_experts
        * (input_dim.size + output_dim.size)
        * hparams.moe_hidden_size)

    The input is n-dimensional: [<batch_and_length_dims>, input_dim],
    consisting of the representations of all positions in a batch of
    sequences.

    Each position of each sequence is sent to 0-3 experts.  The expert
    choices and the combination weights are determined by a learned gating
    function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model.  This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained and
    the rest will starve.

    Several hacks are necessary to get around current TPU limitations:

    - To ensure static shapes, we enforce (by truncation/padding)
      that each sequence send the same number of elements to each expert.

      It would make more sense to enforce this equality over the entire batch,
      but due to our hacked-up gather-by-matmul implementation, we need to
      divide the batch into "groups".  For each group, the same number of
      elements are sent to each expert.

    TODO(noam): Factor this code better.  We want to be able to substitute
    different code for the experts themselves.

    Dimensions cheat sheet:
    a, b: batch size
    l: original sequence length
    m: input depth
    n: output depth
    g, h: number of groups
    s, t: group size
    x, y: number of experts
    c, d: expert capacity

    input: [a0, b1, l, m]
    input: [a0, g1, s, m]
    dispatch_tensor_x: [a0, g1, s, x, c]
    expert_input: [a0, g1, x, c, m]
    alltoall: [a0, g, x1, c, m]
    alltoall: [a0, g, x1, c, m]
    transpose: [x1, a0, g, c, m]
    reshape: [x1, h0, s, m]
    assignment2: [x1, h0, t, y, d]
    expert_input2: [x1, h0, y, d, m]
    alltoall: [x1, h, y0, d, m]
    ...
    reverse of that

    gating params 0: [m, x]
    gating params 1: [x1, m, y]

    expert params:
       [x1, y0, m, hidden]
       [x1, y0, hidden, n]

    Args:
      inputs: a mtf.Tensor with shape [a, b, l, m].  A 3-dimensional input
        [b, l, m] is also accepted: a size-1 "outer_batch" dimension is
        inserted in front and removed again from the output.
      output_dim: a mtf.Dimension (for Transformer, this is input_dim)
      hparams: model hyperparameters
      train: a boolean
      variable_dtype: a mtf.VariableDType
      layout: optional - an input to mtf.convert_to_layout_rules
      mesh_shape: optional - an input to mtf.convert_to_shape
      nonpadding: an optional mtf.Tensor with shape [a, b, l]
        and the same dtype as inputs, consisting of ones(nonpadding)
        and zeros(padding).

    Returns:
      outputs: a Tensor with shape [a, b, l, n] ([b, l, n] for 3-dimensional
        inputs)
      loss: a mtf scalar

    Raises:
      ValueError: on unrecognized hparams.moe_gating
    """
    if nonpadding is not None:
        # Adding to zeros of the inputs' leading dims broadcasts nonpadding to
        # that shape.
        nonpadding = mtf.zeros(inputs.mesh,
                               inputs.shape.dims[:-1],
                               dtype=inputs.dtype) + nonpadding
    insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
    if insert_outer_batch_dim:
        inputs = mtf.reshape(inputs, [mtf.Dimension("outer_batch", 1)] +
                             inputs.shape.dims)

    assert len(hparams.moe_num_experts) == 2
    a0, b1, l, m = inputs.shape.dims
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    # x1/y0 and x/y have the same sizes but different names ("_unsplit"
    # variants), so the reshapes below can switch between them.
    x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
    y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
    x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
    y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
    n = output_dim

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (g.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        b1.size * l.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
    g1 = mtf.Dimension(b1.name, num_groups)
    g = mtf.Dimension(b1.name + "_unsplit", g1.size)
    s = mtf.Dimension("group_size_x", group_size)

    # Each sequence sends (at most?) expert_capacity positions to each expert.
    # Static expert_capacity dimension is needed for expert batch sizes
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
    # Capacity is never allowed to drop below 4.
    expert_capacity = max(expert_capacity, 4)
    c = mtf.Dimension("expert_capacity_x", expert_capacity)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (h.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        a0.size * g.size * c.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
    t = mtf.Dimension("group_size_y", group_size)
    h0 = mtf.Dimension(a0.name, num_groups)
    h = mtf.Dimension(a0.name + "_unsplit", h0.size)

    expert_capacity = min(
        t.size,
        int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
    expert_capacity = max(expert_capacity, 4)
    d = mtf.Dimension("expert_capacity_y", expert_capacity)

    # First level of expert routing
    # Reshape the inner batch size to a multiple of group_dim g1 and
    # group_size_dim s.
    inputs = mtf.reshape(inputs, [a0, g1, s, m])
    if nonpadding is not None:
        nonpadding = mtf.reshape(nonpadding, [a0, g1, s])

    # Get the assignments for the first level.
    # dispatch_tensor_x has shape [a0, g1, s, x, c]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=x,
            expert_capacity_dim=c,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            name="outer_gating",
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x],
                                 [x, a0, g1, c, m])

    # we construct an "importance" Tensor for the inputs to the second-level
    # gating.  The importance of an input is 1.0 if it represents the
    # first-choice expert-group and 0.5 if it represents the second-choice expert
    # group.  This is used by the second-level gating.
    importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
    importance = 0.5 * (mtf.to_float(mtf.greater(importance, 0.5)) +
                        mtf.to_float(mtf.greater(importance, 0.0)))

    # First level, all to all. Here we change the split dimension from g1 to x1.
    expert_inputs_x = mtf.reshape(expert_inputs_x,
                                  mtf.Shape([x1, a0, g, c, m]))
    importance = mtf.reshape(importance, [x1, a0, g, c])

    # Second level of expert routing
    # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
    # and group_size_dim t.
    inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
    importance = mtf.reshape(importance, [x1, h0, t])

    # Get the assignments for the second level.
    # dispatch_tensor_y has shape [x1, h0, t, y, d]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
            inputs=inputs_y,
            outer_expert_dims=[x1],
            experts_dim=y,
            expert_capacity_dim=d,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=importance,
            name="inner_gating")
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                                 [y, x1, h0, d, m])

    # Second level, all to all. Here we change the split dimension from h0 to y0.
    expert_inputs_y = mtf.reshape(expert_inputs_y,
                                  mtf.Shape([y0, x1, h, d, m]))

    # The experts themselves: two bias-free dense layers ("wi" with relu,
    # then "wo") with one set of weights per (y0, x1) expert.
    hidden_output = mtf.layers.dense(expert_inputs_y,
                                     hidden_dim,
                                     expert_dims=[y0, x1],
                                     activation=mtf.relu,
                                     use_bias=False,
                                     variable_dtype=variable_dtype,
                                     name="wi")
    expert_output = mtf.layers.dense(hidden_output,
                                     output_dim,
                                     expert_dims=[y0, x1],
                                     use_bias=False,
                                     variable_dtype=variable_dtype,
                                     name="wo")

    # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
    # expert_output has shape [y0, x1, h, d, n]

    # alltoall
    expert_output = mtf.reshape(expert_output, mtf.Shape([y, x1, h0, d, n]))

    # combine results from inner level
    output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

    # Reshape the combined tensor from inner level to now contain outer_batch_dim
    # a0 and group_dim g
    output = mtf.reshape(output_y, [x1, a0, g, c, n])

    # alltoall from expert_dim x to group_dim g1
    expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

    # combine results from outer level
    output_x = mtf.einsum([expert_output_x, combine_tensor_x],
                          [a0, g1, s, n])

    # Reshape the combined tensor to now contain inner_batch_dim
    # b1 and the original sequence length
    output = mtf.reshape(output_x, [a0, b1, l, n])
    if insert_outer_batch_dim:
        # Drop the size-1 outer batch dimension that was added on entry.
        output = mtf.reshape(output, [b1, l, n])
    return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
def compute_loss(logits, positions):
    """Return the mean negative log-likelihood of `positions` under `logits`.

    The softmax is taken over `seq_dim`, which is read from the enclosing
    scope rather than passed in.

    Args:
      logits: an mtf.Tensor containing the `seq_dim` dimension.
      positions: an mtf.Tensor of target indices along `seq_dim`.

    Returns:
      The negated mean, over the remaining dimensions, of the log-probability
      assigned to each target position.
    """
    targets = mtf.one_hot(positions, output_dim=seq_dim)
    log_probabilities = mtf.log_softmax(logits, seq_dim)
    # The one-hot product keeps only the log-probability at the target index.
    target_log_likelihood = mtf.reduce_sum(
        targets * log_probabilities, reduced_dim=seq_dim)
    return -mtf.reduce_mean(target_log_likelihood)
def recon_model(mesh,
                data,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the reconstruction graph: forward model, prior and loss.

    A trainable field variable named 'linear' is evolved (LPT only, or LPT
    followed by N-body steps when FLAGS.nbody is set), painted onto a mesh,
    smoothed in Fourier space with a Gaussian kernel controlled by R0, and
    compared with `data` through a Poisson negative log-likelihood.

    Args:
      mesh: the mtf mesh on which all tensors are created.
      data: observed field; converted with tf.convert_to_tensor and imported
        with shape [batch, nx, ny, nz].
      R0: smoothing scale used in the Gaussian Fourier kernel.
      x0: initial value for the field variable, or None for a random normal
        initialisation.
      nc: number of mesh cells per side.
      bs: box size.
      batch_size: size of the batch dimension.
      a0: first scale factor of the stages.
      a: last scale factor of the stages.
      nsteps: number of stages between a0 and a (inclusive).
      dtype: tf.float32 or tf.float64.

    Returns:
      fields: [fieldvar, sample] - the field variable and a Poisson sample
        drawn from the smoothed final field.
      metrics: [chisq, prior, loss] with loss = chisq + prior.
      kv: the list of Fourier kernel tensors [ky, kz, kx].

    Raises:
      ValueError: if dtype is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch npdtype/cdtype would be unbound and the print
        # below would fail with a NameError.
        raise ValueError(
            "recon_model only supports tf.float32 and tf.float64, got %s" %
            (dtype,))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block.
    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Only the second column of the table is used below, so the file is
    # parsed once.
    plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    if x0 is None:
        fieldvar = mtf.get_variable(
            mesh,
            'linear',
            part_shape,
            initializer=tf.random_normal_initializer(mean=0.0,
                                                     stddev=1,
                                                     seed=None))
    else:
        fieldvar = mtf.get_variable(
            mesh,
            'linear',
            part_shape,
            initializer=tf.constant_initializer(x0))
    print("\nfieldvar : \n", fieldvar)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        final_state = mtfpm.lpt_init_single(
            fieldvar,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    # kv is ordered [ky, kz, kx]; reorder its dimensions to (x, y, z).
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Divide |kfield| by the square root of the interpolated power."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 +
                     (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #*nc**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.constant(R0)
    print("R0 in the recon_model : ", R0)

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply kfield by the Gaussian weights exp(-k^2 (R0 bs / nc)^2)."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    # Element-wise function that applies a Fourier kernel
    plambda = FLAGS.plambda

    def _cwise_logprob(finalfield, data):
        """Negative Poisson log-probability of data at rate plambda*(1+field)."""
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        logprob = galmean.log_prob(data)
        return -1 * logprob

    cfield = mesh_utils.r2c3d(final_field, k_dims_pr, dtype=cdtype)
    cfield = mtf.cwise(_cwise_smooth, [cfield] + kv, output_dtype=cdtype)
    final_fieldsm = mesh_utils.c2r3d(cfield, diff.shape[-3:], dtype=dtype)
    chisq = mtf.cwise(_cwise_logprob, [final_fieldsm, mtfdata],
                      output_dtype=tf.float32)
    chisq = mtf.reduce_sum(chisq)

    loss = chisq + prior

    def _cwise_sample(finalfield, data):
        """Draw a Poisson sample at rate plambda*(1+field); data is unused."""
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        sample = galmean.sample()
        return sample

    sample = mtf.cwise(_cwise_sample, [final_fieldsm, mtfdata],
                       output_dtype=tf.float32)
    fields = [fieldvar, sample]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
def _top_2_gating(inputs,
                  outer_expert_dims,
                  experts_dim,
                  expert_capacity_dim,
                  hparams,
                  train,
                  variable_dtype,
                  importance=None,
                  name="top_2_gating"):
    """Compute gating for mixture-of-experts in TensorFlow.

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters dictionary in order not to complicate the interface in
    mtf_transformer.py .  Once this code moves out of "research", we should
    pass the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_use_second_place_loss: a boolean
      hparams.moe_second_policy_train: a string
      hparams.moe_second_policy_eval: a string
      hparams.moe_second_threshold_train: a float
      hparams.moe_second_threshold_eval: a float

    The returned forward assignment is a tensor used to map (via einsum) from
    the inputs to the expert_inputs.  Likewise, the returned combine_tensor is
    used to map (via einsum) from the expert outputs to the outputs.  Both the
    forward and backward assignments are mostly zeros.  The shapes of the
    tensors are as follows.

    inputs: [<batch_dims>, group_size_dim, input_dim]
    importance: [<batch_dims>, group_size_dim]
    dispatch_tensor:
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    expert_inputs:
      [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

    expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
    combine_tensor:
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    outputs: [<batch_dims>, group_size_dim, output_dim]

    "importance" is an optional tensor with one floating-point value for each
    input vector.  If the importance of an input is 1.0, then we send it to
    up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
    one expert.  If importance == 0.0, then we send it to no experts.

    We use "importance" at the second-level gating function of a hierarchical
    mixture of experts.  Inputs to the first-choice expert-group get importance
    1.0.  Inputs to the second-choice expert group get importance 0.5.
    Inputs that represent padding get importance 0.0.

    Args:
      inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
      outer_expert_dims: an optional list of dimensions.  This is for the case
        where we are at an inner level of a hierarchical MoE.
      experts_dim: a Dimension (the number of experts)
      expert_capacity_dim: a Dimension (number of examples per group per expert)
      hparams: model hyperparameters.
      train: a boolean
      variable_dtype: a mtf.VariableDType
      importance: an optional tensor with shape [<batch_dims>, group_size_dim]
      name: an optional string

    Returns:
      dispatch_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      combine_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      loss: a mtf scalar

    Raises:
      ValueError: on illegal hyperparameters
    """
    group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

    raw_gates = mtf.layers.dense(inputs,
                                 experts_dim,
                                 use_bias=False,
                                 expert_dims=outer_expert_dims,
                                 variable_dtype=variable_dtype,
                                 name=name)
    raw_gates = mtf.softmax(raw_gates, experts_dim)

    # The internals of this function run in float32.
    #   bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    density_1_proxy = raw_gates
    if importance is not None:
        # Only inputs with importance exactly 1.0 keep a first-choice expert.
        mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
        mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

    # Renormalize the two gates; 1e-9 guards against division by zero.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                      reduced_dim=group_size_dim)
    loss = (mtf.reduce_mean(density_1_proxy * density_1) *
            float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as the
        # second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top one
        # has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized,
                                          reduced_dim=group_size_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2) *
                  float(experts_dim.size * experts_dim.size))
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.
    if train:
        policy = hparams.moe_second_policy_train
        threshold = hparams.moe_second_threshold_train
    else:
        policy = hparams.moe_second_policy_eval
        threshold = hparams.moe_second_threshold_eval
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts for all examples.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    position_in_expert_1 = mtf.cumsum(mask_1, group_size_dim,
                                      exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    # Second choices are placed after the first choices already assigned to
    # the same expert (hence the mask_1_count offset).
    position_in_expert_2 = (
        mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
    position_in_expert_2 *= mask_2
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # [batch, group, experts, expert_capacity]
    # NOTE(review): gate_1/gate_2 were already multiplied by mask_1_flat /
    # mask_2_flat above and are multiplied by them again here; presumably
    # harmless because the masks are 0/1 - confirm.
    combine_tensor = (
        gate_1 * mask_1_flat * mtf.one_hot(index_1, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat * mtf.one_hot(index_2, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)

    # The dispatch tensor is the combine tensor passed through a bool cast
    # and back to its own dtype.
    dispatch_tensor = mtf.cast(mtf.cast(combine_tensor, tf.bool),
                               combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
def recon_model(mesh,
                data,
                bparams,
                ipkerror,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the Mesh TensorFlow graph of the reconstruction objective.

    A trainable field named 'linear' is evolved with
    ``mtfpm.lpt_init_single`` and ``mtfpm.nbody_single``, painted onto the
    mesh, and turned into a model field as a weighted sum of three painted
    fields (weights ``b1``, ``b2``, ``bs2``).  The loss is the sum of a
    chi-square term on the filtered residual ``model - data`` and a prior
    term on the Fourier transform of the trainable field.

    Args:
        mesh: the mtf mesh on which every tensor is created.
        data: observed field; converted with ``tf.convert_to_tensor`` and
            imported with shape [batch, nx, ny, nz].
        bparams: sequence of three coefficients ``(b1, b2, bs2)``.
        ipkerror: indexable pair of numpy arrays; ``ipkerror[0]`` supplies
            the interpolation range (its min and max) and ``ipkerror[1]``
            the values by which the residual is whitened.
        R0: scale of the Gaussian filter ``exp(-kk * (R0 * bs / nc)**2)``
            applied to the residual.
        x0: initial value of the trainable field, or None to initialise it
            from a unit normal distribution.
        nc: number of cells per side of the mesh.
        bs: box size.
        batch_size: size of the batch dimension.
        a0: first entry of the ``stages`` array, also passed to
            ``mtfpm.lpt_init_single``.
        a: last entry of the ``stages`` array.
        nsteps: number of entries in ``stages``.
        dtype: ``tf.float32`` or ``tf.float64``.

    Note that the keyword defaults are read from ``FLAGS`` once, when the
    function is defined.

    Returns:
        A tuple ``(fields, metrics, kv)`` where ``fields`` is
        ``[fieldvar, final_field, model]``, ``metrics`` is
        ``[chisq, prior, loss]`` and ``kv`` is the list ``[ky, kz, kx]`` of
        imported Fourier-axis tensors.

    Raises:
        ValueError: if ``dtype`` is neither ``tf.float32`` nor
            ``tf.float64``.
    """
    b1, b2, bs2 = bparams
    kerror = ipkerror[0].astype(np.float32)
    perror = ipkerror[1].astype(np.float32)

    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and crashed later with an unbound
        # local; fail early with an explicit message instead.
        raise ValueError(
            "Unsupported dtype %s; expected tf.float32 or tf.float64" % dtype)
    print("Dtype : ", dtype, npdtype)

    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo is capped at half of the smallest block extent.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Only column 1 of the table is used below, so the file is parsed once.
    # NOTE(review): the prior interpolates these values on a log-regular
    # grid spanning [1e-5, 1e3]; presumably that matches the sampling of
    # column 0 of the table - confirm against the data file.
    plin = np.loadtxt('..//data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])
    pke_dim = mtf.Dimension("epk", len(perror))
    pkerror = mtf.import_tf_tensor(mesh,
                                   perror.astype(npdtype),
                                   shape=[pke_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    # Begin simulation
    if x0 is None:
        initializer = tf.random_normal_initializer(mean=0.0,
                                                   stddev=1,
                                                   seed=None)
    else:
        initializer = tf.constant_initializer(x0)
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=initializer)

    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                     kv_lr, halo_size)

    # paint the field
    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field, final_state, part_shape,
                                     hr_shape, halo_size, splittables, mesh)

    # Get the fields for bias
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    # Dimensions of the Fourier axes, reordered from kv's (y, z, x) order
    # to (x, y, z).  kv never changes, so this is computed once and reused
    # for every r2c3d call below.
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    tfnc = cswisef.float_to_mtf(nc * 1., mesh, scalar)
    tfbs = cswisef.float_to_mtf(bs, mesh, scalar)

    # Mean-subtracted field, its square, and the shear term.
    initc = fieldvar
    d0 = initc - mtf.reduce_mean(initc)

    d2 = initc * initc
    d2 = d2 - mtf.reduce_mean(d2)

    cfield = mesh_utils.r2c3d(d0, k_dims_pr, dtype=cdtype)
    shearfield = mtf.zeros(mesh, shape=part_shape)
    shearfield = shear(shearfield, cfield, kv, tfnc, tfbs)
    s2 = shearfield - mtf.reduce_mean(shearfield)

    dread = mcomp.cic_readout_fr(d0, [X],
                                 hr_shape=hr_shape,
                                 halo_size=halo_size,
                                 splittables=splittables,
                                 mesh=mesh)
    d2read = mcomp.cic_readout_fr(d2, [X],
                                  hr_shape=hr_shape,
                                  halo_size=halo_size,
                                  splittables=splittables,
                                  mesh=mesh)
    s2read = mcomp.cic_readout_fr(s2, [X],
                                  hr_shape=hr_shape,
                                  halo_size=halo_size,
                                  splittables=splittables,
                                  mesh=mesh)

    ed, ed2, es2 = mtf.zeros(mesh, shape=part_shape), mtf.zeros(
        mesh, shape=part_shape), mtf.zeros(mesh, shape=part_shape)
    ed = mcomp.cic_paint_fr(ed,
                            final_state,
                            output_shape=part_shape,
                            hr_shape=hr_shape,
                            halo_size=halo_size,
                            splittables=splittables,
                            mesh=mesh,
                            weights=dread)
    ed2 = mcomp.cic_paint_fr(ed2,
                             final_state,
                             output_shape=part_shape,
                             hr_shape=hr_shape,
                             halo_size=halo_size,
                             splittables=splittables,
                             mesh=mesh,
                             weights=d2read)
    es2 = mcomp.cic_paint_fr(es2,
                             final_state,
                             output_shape=part_shape,
                             hr_shape=hr_shape,
                             halo_size=halo_size,
                             splittables=splittables,
                             mesh=mesh,
                             weights=s2read)

    model = ed * b1 + ed2 * b2 + es2 * bs2
    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)
    diff = model - mtfdata

    def _k_squared(kx, ky, kz):
        # Broadcast the three k axes into a 3-D grid of squared norms.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        return (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2

    # Get prior
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # |field| divided by the square root of the interpolated spectrum.
        kk = tf.sqrt(_k_squared(kx, ky, kz))
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #* nc**3

    # Total loss
    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)

    def _cwise_diff(kfield, pk, kx, ky, kz):
        # Divide the residual by the square root of the error spectrum.
        # NOTE(review): the interpolation has no regularizing transform, so
        # it presumably expects kerror to be linearly spaced - confirm.
        kk = tf.sqrt(_k_squared(kx, ky, kz))
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(x=kk,
                                                 x_ref_min=kerror.min(),
                                                 x_ref_max=kerror.max(),
                                                 y_ref=pk)
        priormesh = tf.reshape(pkmesh, kshape)
        priormesh = tf.cast(priormesh**0.5, kfield.dtype)
        return kfield / priormesh

    cdiff = mtf.cwise(_cwise_diff, [cdiff, pkerror] + kv,
                      output_dtype=cdtype)

    def _cwise_smooth(kfield, kx, ky, kz):
        # Gaussian low-pass filter whose scale is set by R0.
        kk = _k_squared(kx, ky, kz)
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
def sample_autoregressive(self,
                          partial_sequences,
                          stop_at_token=1,
                          max_steps=None,
                          temperature=1.0,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None):
    """Complete each partial sequence by sampling one token per step.

    Every row of `partial_sequences` holds a nonzero prefix followed by
    zeros; the zeros are the slots to fill.  The number of nonzero entries
    of a row is the position at which decoding of that row starts.

    To sample from scratch, pass
    partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) together with
    has_partial_sequences=False; the states recorded by the first pass are
    then replaced by zeros.

    Args:
      partial_sequences: an int32 Tensor with shape [batch_dims..., length_dim]
      stop_at_token: an optional integer eos id; decoding stops once every
        sequence contains it (or the position reaches the end of length_dim).
      max_steps: an optional integer (currently ignored).
      temperature: an optional floating point value between 0.0 and 1.0.
        0.0 means argmax, 1.0 means sample according to the predicted
        distribution.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer

    Returns:
      a Tensor with shape [batch_dims..., length_dim]

    Raises:
      ValueError: if this model is not autoregressive.
    """
    del max_steps  # TODO(noam): implement
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    dims = partial_sequences.shape.dims
    batch_dims = dims[:-1]
    length_dim = dims[-1]
    # Count of nonzero tokens per sequence = first position to generate.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(partial_sequences, 0)),
        reduced_dim=length_dim)

    # Keyword arguments common to the first-pass and per-step contexts.
    shared_context_kwargs = dict(
        mesh=partial_sequences.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        autoregressive=self.autoregressive,
        sequence_id=None if encoder_sequence_id is None else 1,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    first_pass_context = Context(
        mode="first_part",
        new_states=[],
        initial_position=initial_position,
        constant_states=[],
        **shared_context_kwargs)

    shifted_inputs = mtf.shift(
        partial_sequences, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        # Run only for the states it records on the context; the logits
        # themselves are discarded.
        self._call_internal(first_pass_context, shifted_inputs)
    constant_states = first_pass_context.constant_states
    initial_states = (
        first_pass_context.new_states if has_partial_sequences
        else [mtf.zeros_like(t) for t in first_pass_context.new_states])

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration."""
        finished = mtf.greater_equal(position, length_dim.size)
        if stop_at_token is not None:
            produced_eos = mtf.reduce_any(
                mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
            finished = mtf.logical_or(finished, produced_eos)
        return mtf.logical_not(mtf.reduce_all(finished))

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        step_context = Context(
            mode="incremental",
            position=position,
            states=states,
            new_states=[],
            constant_states=constant_states,
            **shared_context_kwargs)
        previous_ids = mtf.gather(ids, position - 1, length_dim)
        with tf.variable_scope(self.name, reuse=True):
            step_logits = self._call_internal(step_context, previous_ids)
        sampled_ids = mtf.sample_with_temperature(
            step_logits, self.output_vocab_dim, temperature)
        next_position = position + 1
        # Write the sampled token into the slot at `position`.
        next_ids = ids + sampled_ids * mtf.one_hot(
            position, length_dim, dtype=tf.int32)
        return [next_position, next_ids] + step_context.new_states

    loop_vars = [initial_position, partial_sequences] + initial_states
    # Only the ids (second loop variable) are returned to the caller.
    _, outputs = mtf.while_loop(cond_fn, body_fn, loop_vars)[:2]
    return outputs
def sample_autoregressive(
        partial_sequences,
        other_features,
        params,
        stop_at_token=50256,
        max_steps=None,
        temperature=0.9,
        variable_dtype=mtf.VariableDType(tf.float32),
        encoder_output=None,
        encoder_sequence_id=None,
        encoder_inputs=None,
        shared_params=None,
        has_partial_sequences=True,
        encoder_layer_outputs=None,
        never_end=False,
        remove_partial_sequences=False,
        sampling_keep_top_k=-1,
        bos_id=50256,
):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued.  The
    first tokens of each sequence differ from the padding id and represent
    the given partial sequences; the last tokens of each sequence are
    padding, representing what needs to be filled in.

    If there are no partial sequences (you want to sample from the
    beginning), then pass
    partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
        partial_sequences: an int32 Tensor with shape
            [batch_dims..., length_dim]
        other_features: a dictionary forwarded to `gpt2.model`; its
            "vocab_dim" entry is the dimension sampled over.
        params: a dictionary forwarded to `gpt2.model`.  "padding_id"
            (default 0) is the id that marks unfilled positions and
            "slow_sampling" (default False) disables the incremental
            decoding context, so the model is called on the whole sequence
            at every step.
        stop_at_token: an optional integer eos id.  Stop when we produce it.
        max_steps: an optional integer, the max number of steps to decode.
        temperature: an optional floating point value between 0.0 and 1.0.
            0.0 means argmax, 1.0 means sample according to the predicted
            distribution.
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        shared_params: an optional dictionary
        has_partial_sequences: a boolean
        encoder_layer_outputs: optional - readonly list of tensor
            activations when decoding, one per each input layer + the
            embedding layer
        never_end: a boolean - if set, then avoid generating stop_at_token
        remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
        sampling_keep_top_k: an integer - if not -1, only sample from the
            top k logits.  The special value -2 keeps the top 10% of the
            last logits dimension.
        bos_id: beginning of sequence id (currently unused).

    Returns:
        a Tensor with shape [batch_dims..., length_dim]

    Raises:
        ValueError: if the effective `sampling_keep_top_k` is neither -1
            nor positive.
    """
    inputs = partial_sequences  # Partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    # Number of non-padding tokens per sequence, i.e. where padding starts.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)),
        reduced_dim=length_dim)

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    # Priority is the position index for positions beyond the prompt length
    # and 0 elsewhere.  (The plain `length_range` alternative used to sit
    # behind a flag hard-coded to True and was unreachable.)
    # NOTE(review): presumably this lets all prompt positions attend to one
    # another - confirm against the attention implementation.
    read_priority = write_priority = length_range * mtf.to_int32(
        mtf.greater(length_range, initial_position))

    # Builds context to pass around internally
    # The 'first part' context records initial states of k / v / x
    if not slow_sampling:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2"):
            # Called for the states it records on the context; the outputs
            # are not needed.
            gpt2.model({"inputs": inputs},
                       other_features,
                       params,
                       inputs.mesh,
                       variable_dtype=variable_dtype,
                       context=context_first_part)

        if not has_partial_sequences:
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
        else:
            initial_states = context_first_part.new_states
    else:
        initial_states = []

    # Stop tokens already present in the prompt must not end decoding, so
    # count them up front and only react to additional ones.
    partial_sequences_eos_count = 0
    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count,
                                             partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=ids,
            encoder_inputs=encoder_inputs) if not slow_sampling else None

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)

        vocab_dim = other_features["vocab_dim"]

        if never_end and stop_at_token is not None:
            # Honour the documented contract: add a large negative bias to
            # the stop token's logit so that it is effectively never drawn.
            logits += mtf.one_hot(
                mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32),
                vocab_dim,
                on_value=-1e9,
                off_value=0.0,
                dtype=logits.dtype)

        # Resolve the effective k locally instead of rebinding the enclosing
        # argument from inside the loop body.  -2 means "keep the top 10%".
        keep_top_k = sampling_keep_top_k
        if keep_top_k == -2:
            keep_top_k = int(logits.shape[-1].size * 0.1)

        if keep_top_k != -1:
            if keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            # Everything at or below the threshold element is pushed to a
            # very low logit so it is effectively excluded from sampling.
            k_largest = mtf.nth_largest_element(logits,
                                                n=keep_top_k,
                                                reduced_dim=vocab_dim)
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(logits, vocab_dim,
                                                    temperature)

        if slow_sampling:
            ids_this_step = mtf.shift(ids_this_step,
                                      offset=1,
                                      dim=length_dim,
                                      wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, (batch_dims))

        # Overwrite (rather than add to) the slot at `position`.
        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(cond_fn, body_fn,
                                             while_loop_inputs)[:2]
    del final_position
    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs
        partial_length = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),
            reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(outputs,
                                    -partial_length,
                                    length_dim,
                                    wrap=False)
    return outputs
def decode(self,
           inputs,
           variable_dtype=mtf.VariableDType(tf.float32),
           beam_size=1,
           alpha=0.6,
           temperature=0.0,
           decode_length_multiplier=1.5,
           decode_length_constant=10):
    """Encode `inputs`, then decode by sampling or by beam search.

    With beam_size == 1 the decoder samples autoregressively from an
    all-zero start; otherwise it runs beam search with a decode length of
    max_input_length * decode_length_multiplier + decode_length_constant,
    where the input length counts the nonzero tokens of a sequence.

    TODO(noam): should we make the output length dimension different from
    the input length dimension?

    Args:
      inputs: a Tensor with shape [batch_dims..., length_dim]
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted
        distribution.
      decode_length_multiplier: a float
      decode_length_constant: a float

    Returns:
      whatever the decoder's sample_autoregressive (beam_size == 1) or
      beam_search (beam_size > 1) returns for the decoded ids.

    Raises:
      ValueError: if beam_size != 1 and temperature is nonzero.
    """
    mesh = inputs.mesh
    layer_outputs = []
    shared_params = self._shared_params(mesh, variable_dtype)
    # 1 for real tokens, 0 for padding.
    sequence_id = mtf.minimum(inputs, 1)
    # The second element of the result (the encoder loss) is not used.
    encoded, _ = self.encoder.call_simple(
        inputs=inputs,
        targets=None,
        compute_loss=False,
        mode=tf.estimator.ModeKeys.PREDICT,
        variable_dtype=variable_dtype,
        sequence_id=sequence_id,
        shared_params=shared_params,
        layer_outputs=layer_outputs)
    encoded = mtf.layers.rename_length_to_memory_length(encoded)
    sequence_id = mtf.layers.rename_length_to_memory_length(sequence_id)

    # Arguments passed to the decoder on both code paths.
    decoder_kwargs = dict(
        variable_dtype=variable_dtype,
        encoder_output=encoded,
        encoder_sequence_id=sequence_id,
        shared_params=shared_params,
        encoder_layer_outputs=layer_outputs)

    if beam_size == 1:
        start_ids = mtf.zeros(mesh, inputs.shape, dtype=tf.int32)
        return self.decoder.sample_autoregressive(
            start_ids,
            temperature=temperature,
            has_partial_sequences=False,
            **decoder_kwargs)

    if temperature != 0:
        raise ValueError(
            "don't know how to beam search with nonzero temperature")

    # beam search
    beam_dim = mtf.Dimension("beam", beam_size)
    length_dim = inputs.shape[-1]
    beam_ids_shape = mtf.Shape(inputs.shape[:-1] + [beam_dim, length_dim])
    start_ids = mtf.zeros(mesh, beam_ids_shape, dtype=tf.int32)
    nonzero_count = mtf.reduce_sum(
        mtf.to_float(mtf.cast(inputs, tf.bool)), reduced_dim=length_dim)
    longest_input = mtf.reduce_max(nonzero_count)
    decode_length = mtf.cast(
        longest_input * decode_length_multiplier + decode_length_constant,
        tf.int32)
    return self.decoder.beam_search(
        start_ids,
        decode_length,
        alpha=alpha,
        **decoder_kwargs)
def _sample(self, features, mesh):
    """Build the decoding graph and return the decoded ids.

    For hparams.transformer_type == "encdec" the encoder is run on
    features["inputs"] and the per-layer encoder-decoder attention keys and
    values are precomputed.  For "decoder", features["inputs"] (or, failing
    that, features["targets"]) is used, when present, as a prefix that the
    output is forced to start with.

    Args:
      features: a dictionary of tf Tensors; see above for the keys read.
      mesh: the mtf mesh on which the graph is built.

    Returns:
      With hparams.beam_size == 1, the result of
      mtf.beam_search.greedy_decode; otherwise the entry at index 0 along
      the beam dimension of the beams returned by
      mtf.beam_search.beam_search.

    Raises:
      ValueError: if hparams.transformer_type is neither "encdec" nor
        "decoder".
    """
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
        inputs = features["inputs"]
        # Drop extra axes until the inputs are [batch, length].
        while len(inputs.shape.as_list()) > 2:
            inputs = tf.squeeze(inputs, axis=2)
        actual_batch_size = tf.shape(inputs)[0]
        actual_length = tf.shape(inputs)[1]
        # Pad up to the static batch size and length from hparams.
        inputs = tf.pad(
            inputs, [[0, hparams.batch_size - actual_batch_size],
                     [0, hparams.max_length - actual_length]])
        inputs = self._import_to_batch_by_length(
            inputs, "inputs", mesh, hparams)
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.reshape(positional_embedding_var,
                         mtf.Shape([self.length_dim, self.model_dim])))
        encoder_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
        with tf.variable_scope("encoder"):
            x = self._layer_stack(x,
                                  hparams.encoder_layers,
                                  self_attention_mask=encoder_attention_mask)
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)
        # One entry per decoder layer: (q_var, o_var, k, v) for "enc_att"
        # layers, None for every other layer type.
        encdec_tensors = []
        for layer_num, layer_type in enumerate(hparams.decoder_layers):
            if layer_type == "enc_att":
                with tf.variable_scope(
                        "decoder/enc_att_%d/enc_att" % layer_num):
                    q_var, k_var, v_var, o_var = (
                        mtf.layers.multihead_attention_vars(
                            mesh, self.heads_dim, self.model_dim,
                            self.kv_dim, self.master_dtype, self.slice_dtype,
                            self.activation_dtype))
                    k = mtf.einsum(
                        [encoder_output, k_var],
                        mtf.Shape(
                            self.batch_dims + [self.heads_dim,
                                               self.memory_length_dim,
                                               self.kv_dim]))
                    v = mtf.einsum(
                        [encoder_output, v_var],
                        mtf.Shape(
                            self.batch_dims + [self.heads_dim,
                                               self.memory_length_dim,
                                               self.kv_dim]))
                encdec_tensors.append((q_var, o_var, k, v))
            else:
                encdec_tensors.append(None)
        partial_targets = None
    elif hparams.transformer_type == "decoder":
        encdec_tensors = None
        encoder_output = None
        encoder_attention_mask = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs", None)
        if partial_targets is None:
            partial_targets = features.get("targets", None)
        if partial_targets is not None:
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int32(partial_targets)
            partial_targets_batch = tf.shape(partial_targets)[0]
            partial_targets_length = tf.shape(partial_targets)[1]
            partial_targets = tf.pad(
                partial_targets,
                [[0, hparams.batch_size - partial_targets_batch],
                 [0, hparams.max_length - partial_targets_length]])
            partial_targets = self._import_to_batch_by_length(
                partial_targets, "partial_targets", mesh, hparams)
    else:
        # The hparam consulted above is `transformer_type`; name it
        # correctly so the message points at the setting to change.
        raise ValueError(
            "hparams.transformer_type = %s not yet supported"
            % hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    if hparams.beam_size == 1:
        ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        kv_shape = mtf.Shape(self.batch_dims +
                             [self.heads_dim,
                              self.memory_length_dim, self.kv_dim])
        local_kv_shape = mtf.Shape(self.batch_dims +
                                   [self.heads_dim,
                                    local_attention_window, self.kv_dim])
    else:
        beam_dim = mtf.Dimension("beam", hparams.beam_size)
        ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
        kv_shape = mtf.Shape(self.batch_dims +
                             [beam_dim, self.heads_dim,
                              self.memory_length_dim, self.kv_dim])
        local_kv_shape = mtf.Shape(self.batch_dims +
                                   [beam_dim, self.heads_dim,
                                    local_attention_window, self.kv_dim])
    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    # Two zero states (presumably keys and values - confirm) per attention
    # layer; other layer types contribute no state.
    initial_states = []
    for layer in hparams.decoder_layers:
        if layer == "att":
            initial_states.extend(
                [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
        elif layer == "local_att":
            initial_states.extend(
                [mtf.zeros(
                    mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

    def logits_fn(step_num, ids, states):
        """Produce logits for this step, and new states."""
        ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
        x = (mtf.gather(targets_embedding_var, ids_this_step,
                        self.targets_vocab_dim) +
             mtf.gather(positional_embedding_var, step_num,
                        self.max_length_dim))
        with tf.variable_scope("decoder"):
            x, new_states = self._layer_stack(
                x,
                hparams.decoder_layers,
                encdec_attention_mask=encoder_attention_mask,
                step_num=step_num,
                encdec_tensors=encdec_tensors,
                states=states)
        logits = mtf.matmul(x, softmax_var)
        return logits, new_states

    if hparams.beam_size == 1:
        temperature = (0.0 if hparams.sampling_method == "argmax"
                       else hparams.sampling_temp)
        return mtf.beam_search.greedy_decode(
            logits_fn,
            initial_ids,
            temperature=temperature,
            initial_states=initial_states,
            forced_ids=partial_targets,
            use_tpu=hparams.use_tpu)
    else:
        if hparams.transformer_type == "encdec":
            # Decode length scales with the longest input, where length is
            # the count of nonzero tokens.
            input_length = mtf.reduce_sum(
                mtf.to_float(mtf.cast(inputs, tf.bool)),
                reduced_dim=self.length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * hparams.decode_length_multiplier
                + hparams.decode_length_constant, tf.int32)
        else:
            decode_length = None
        beams, unused_scores = mtf.beam_search.beam_search(
            logits_fn,
            initial_ids,
            hparams.alpha,
            states=initial_states,
            decode_length=decode_length,
            use_tpu=hparams.use_tpu,
            dtype=self.activation_dtype)
        # Keep only the beam at index 0.
        return mtf.gather(
            beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)