def _sigmoid_tree(self, tensor):
  """Turns pre-gate logits into a 4-way distribution via a depth-2 sigmoid tree.

  The sigmoid of `tensor` is split along `self._pre_gates_dim` into one slice
  per element. Slice 0 acts as the root decision, slices 1 and 2 as the left
  and right child decisions. Each of the four leaves gets the product of the
  decisions on its root-to-leaf path, and the leaves are concatenated along
  the dimension named `self._gates_dim.name`.

  Args:
    tensor: an mtf.Tensor containing `self._pre_gates_dim`; only the first
      three slices along it are used.

  Returns:
    an mtf.Tensor holding the four leaf probabilities along the gates dim.
  """
  gammas = mtf.split(mtf.sigmoid(tensor), self._pre_gates_dim,
                     self._pre_gates_dim.size)
  root, left, right = gammas[0], gammas[1], gammas[2]
  leaf_probs = [
      root * left,
      root * (1 - left),
      (1 - root) * right,
      (1 - root) * (1 - right),
  ]
  return mtf.concat(leaf_probs, self._gates_dim.name)
def hidden_to_logits(self, hidden, context):
  """Builds full-vocabulary logits from the per-cluster logits.

  Every cluster only emits logits for the tokens it owns, so joining the
  cluster outputs along the vocab dimension covers the whole vocabulary.

  Args:
    hidden: an mtf.Tensor of hidden states, passed to every cluster.
    context: the transformer context, passed to every cluster.

  Returns:
    an mtf.Tensor of logits concatenated along `self._vocab_dim`.
  """
  per_cluster_logits = []
  for cluster in self._clusters:
    per_cluster_logits.append(
        cluster.hidden_to_logits(hidden, context=context))
  return mtf.concat(per_cluster_logits, concat_dim_name=self._vocab_dim.name)
def Concat(tsr_lst, name=None):
    """Concatenates tensors along their last dimension, whatever its name.

    mtf.concat requires every input to carry the concatenation dimension under
    the same name. This helper renames the last dimension of each input to one
    fresh name from utils.RandName() and then concatenates along it.

    Args:
        tsr_lst: non-empty list of mtf.Tensors whose dimensions, apart from
            the last one, are identical.
        name: optional name for the concat operation.

    Returns:
        An mtf.Tensor whose last dimension carries the generated name.

    Raises:
        ValueError: if tsr_lst is empty, the leading dimensions differ, or a
            last dimension's name starts with 'axis'.
    """
    if not tsr_lst:
        raise ValueError('Concat requires at least one tensor.')
    leading_dims = tsr_lst[0].shape[:-1]
    if not all(leading_dims == t.shape[:-1] for t in tsr_lst[1:]):
        raise ValueError(
            'All tensors must share every dimension except the last one.')

    concat_dim_name = utils.RandName()
    concat_tsrs = []
    for t in tsr_lst:
        last_dim_name = t.shape[-1].name
        # NOTE(review): 'axis'-prefixed names are rejected; presumably they
        # are reserved for generated dimension names elsewhere - confirm.
        if last_dim_name.startswith('axis'):
            raise ValueError(
                'Last dimension name must not start with "axis": %s' %
                last_dim_name)
        # The original called `mt.rename_dimension`, a typo for the `mtf`
        # module used for `concat` below.
        concat_tsrs.append(
            mtf.rename_dimension(t, last_dim_name, concat_dim_name))

    return mtf.concat(concat_tsrs, concat_dim_name, name)
def hidden_to_logits(self, hidden: mtf.Tensor,
                     context: transformer.Context) -> mtf.Tensor:
  """Function called by mtf transformer to get the logits.

  The logits are a mixture over `self._gates_dim`:
    * component logits: a tanh projection of `hidden` (through
      `self._context_weights`) multiplied with `self._embedding_weights`;
    * mixture weights: `self._sigmoid_tree` applied to prior logits. For the
      frequent vocab, the prior logits are a tanh projection plus a shared
      term. Every rare-vocab token receives only the shared term.
  During training, Gaussian noise with stddev `self._noise_std_dev` is added
  to the prior logits (when that stddev is non-zero). The mixed logits are
  finally passed through `self._rearrange_sentinels`.

  Args:
    hidden: an mtf.Tensor, hidden model states of the final decoder layer.
    context: a transformer.Context, the context used for the call to the
      transformer.

  Returns:
    An mtf.Tensor, the logits.
  """
  # Scale hidden states by 1/sqrt(size of the output dim).
  hidden *= self._output_dim.size**-0.5

  # Rename the output dim so the einsum reduces over the copy dim.
  # NOTE(review): presumably _context_weights carries both output_dim and
  # copy_output_dim - confirm.
  component_contexts = mtf.einsum([
      mtf.rename_dimension(hidden, self._output_dim.name,
                           self._copy_output_dim.name),
      self._context_weights,
  ],
                                  reduced_dims=[self._copy_output_dim])
  component_contexts = mtf.tanh(component_contexts +
                                self._context_weights_bias)
  component_logits = mtf.einsum([component_contexts, self._embedding_weights],
                                reduced_dims=[self._output_dim])
  component_logits = self._dropout(component_logits, context)

  # Prior (mixture-weight) logits for the frequent part of the vocab.
  prior_tanh = mtf.tanh(
      mtf.einsum([self._prior_weights, hidden],
                 reduced_dims=[self._output_dim]) + self._prior_weights_bias)
  prior_tanh = self._dropout(prior_tanh, context)
  prior_shared_logits = mtf.einsum([self._prior_gates_vector, hidden],
                                   reduced_dims=[self._output_dim])
  prior_frequent_vocab_logits = (
      mtf.einsum([self._prior_vocab_vector, prior_tanh]) +
      prior_shared_logits + self._prior_bias)
  # Rare-vocab tokens all reuse the shared prior term (broadcast via ones).
  prior_logits = mtf.concat([
      prior_frequent_vocab_logits,
      mtf.ones(self._mesh, mtf.Shape([self._rare_vocab_dim]),
               dtype=prior_shared_logits.dtype) * prior_shared_logits
  ], self._vocab_dim.name)
  if context.train and self._noise_std_dev != 0.0:
    prior_logits += mtf.random_normal(self._mesh, prior_logits.shape,
                                      stddev=self._noise_std_dev)
  prior_proportions = self._sigmoid_tree(prior_logits)

  # Mix component logits with the sigmoid-tree proportions over the gates dim.
  logits = mtf.einsum([component_logits, prior_proportions],
                      reduced_dims=[self._gates_dim])
  return self._rearrange_sentinels(logits)
def add_position_timing_signal_func(self, context, x, step):
  """Add n-dimensional embedding as the position (horizontal) timing signal.

  The first position index is chosen by `self.position_start_index`:
    * falsy: start at 0;
    * "random": a random int32 shift bounded by the length of `x`;
    * "step": a shift proportional to `step` out of the maximum number of
      recurrence steps (`act_max_steps` for ACT, else `num_rec_steps`).
  The sinusoidal signal from `self.get_timing_signal_1d` is then added to
  `x` ("add") or concatenated to it ("concat"), according to
  `self.add_or_concat_timing_signal`.

  Args:
    context: mtf context
    x: a tensor with shape [batch, length, depth]
    step: step; only read when `self.position_start_index == "step"`.

  Returns:
    a Tensor with the same shape as x for "add"; for "concat" the signal is
    concatenated along the signal's last dimension.

  Raises:
    ValueError: if `self.position_start_index` or
      `self.add_or_concat_timing_signal` holds an unsupported value.
  """
  if not self.position_start_index:
    index = 0
  elif self.position_start_index == "random":
    # Shift all positions randomly
    # TODO(dehghani): What would be reasonable for max number of shift?
    index = mtf.random_uniform(context.mesh, [],
                               maxval=x.shape.dims[1].size,
                               dtype=tf.int32)
  elif self.position_start_index == "step":
    # Shift positions based on the step
    if self.recurrence_type == "act":
      num_steps = self.act_max_steps
    else:
      num_steps = self.num_rec_steps
    index = mtf.cast(x.shape.dims[1].size * step / num_steps, dtype=tf.int32)
  else:
    # Previously fell through and left `index` unbound (UnboundLocalError).
    raise ValueError(
        "Unknown position_start_index: %r" % (self.position_start_index,))

  length = context.length_dim
  channels = context.model.model_dim
  signal = self.get_timing_signal_1d(context, length, channels,
                                     start_index=index)

  if self.add_or_concat_timing_signal == "add":
    x_with_timing = x + mtf.cast(signal, x.dtype)
  elif self.add_or_concat_timing_signal == "concat":
    # NOTE(review): the original marked this path "Unimplemented". The signal
    # is documented as [1, length, channels], and out_shape drops its leading
    # dim, so the broadcast may not succeed - confirm before relying on it.
    batch_dim = x.shape.dims[0]
    out_shape = mtf.Shape([batch_dim] + signal.shape.dims[1:])
    signal_tiled = mtf.broadcast(signal, out_shape)
    x_with_timing = mtf.concat(
        (x, signal_tiled), concat_dim_name=signal_tiled.dimension_names[-1])
  else:
    # Previously `x_with_timing` was left unbound (UnboundLocalError).
    raise ValueError("Unknown add_or_concat_timing_signal: %r" %
                     (self.add_or_concat_timing_signal,))
  return x_with_timing
def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):
    """Prepends learned memory key/value slots to k and v.

    Memory key/values from the "all attention" paper: two variables
    "mem_k" and "mem_v" of shape [mem_kv_sequence, heads, emb] are created,
    broadcast over the batch, renamed onto the "sequence" dimension and
    concatenated in front of k and v respectively.

    Args:
        k: keys tensor; its last dim is the embedding dim. NOTE(review): the
            broadcast below assumes k/v are [batch, sequence, heads, emb].
        v: values tensor with the same layout as k.
        num_mem_kv: number of memory slots to prepend.
        dim_batch: batch Dimension.
        dim_heads: heads Dimension.
        variable_dtype: mtf.VariableDType for the new variables.
        mesh: the mtf Mesh the variables live on.

    Returns:
        (k, v) with num_mem_kv extra positions on the "sequence" dimension.
    """
    mem_dim = mtf.Dimension("mem_kv_sequence", num_mem_kv)
    emb_dim = k.shape[-1]
    init_std = 1 / math.sqrt(emb_dim.size)
    batched_dims = [dim_batch, mem_dim, dim_heads, emb_dim]

    def _memory_variable(var_name):
        # One [mem, heads, emb] table, normal-initialized with std 1/sqrt(emb).
        return mtf.get_variable(
            mesh,
            var_name,
            mtf.Shape([mem_dim, dim_heads, emb_dim]),
            initializer=tf.random_normal_initializer(stddev=init_std),
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            activation_dtype=variable_dtype.activation_dtype)

    tables = [_memory_variable(var_name) for var_name in ("mem_k", "mem_v")]
    tables = [mtf.broadcast(table, batched_dims) for table in tables]
    tables = [mtf.rename_dimension(table, "mem_kv_sequence", "sequence")
              for table in tables]
    mem_k, mem_v = tables

    return mtf.concat([mem_k, k], "sequence"), mtf.concat([mem_v, v], "sequence")
def _hidden_to_logits(self, hidden, context):
  """Actually compute the logits over the entire vocab.

  The head cluster's log-softmax yields log-probabilities for its own tokens
  (the first `end_token_id` entries) plus one entry per tail cluster. For
  tail cluster i, that entry is added to the tail's own log-softmax.
  Log-probabilities are used as logits, since log-softmax only shifts logits
  by a constant. All pieces are concatenated along the vocab dimension.

  Args:
    hidden: an mtf.Tensor of hidden states.
    context: the transformer context, forwarded to every cluster.

  Returns:
    an mtf.Tensor of logits over `self._vocab_dim`.
  """
  head = self._head_cluster
  head_size = head.end_token_id
  head_log_softmax = head.compute_log_softmax(hidden, context)

  pieces = [head.get_log_softmax_prefix(head_log_softmax, head_size)]
  for tail_index, tail in enumerate(self._tail_clusters):
    tail_log_softmax = tail.compute_log_softmax(hidden, context)
    gate_log_prob = head.get_log_softmax_value(head_log_softmax,
                                               head_size + tail_index)
    pieces.append(gate_log_prob + tail_log_softmax)

  return mtf.concat(pieces, concat_dim_name=self._vocab_dim.name)
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
  """Builds the UNet forward graph, the training loss and eval metrics.

  Inputs and labels are imported as laid-out tensors, transposed into
  spatially partitioned block layout, run through a down/up-sampling UNet
  with skip connections, and scored with a weighted cross-entropy plus a
  (negative) lesion Dice loss.

  Args:
    mesh: a MeshTensorflow.mesh object.
    mesh_impl: a mesh implementation, such as SimdMeshImpl and
      PlacementMeshImpl.
    dataset_str: a string of either train or eval. This is used for
      batch_norm and to choose the batch size; any other value fails an
      assert.
    images: a laid out Tensor with shape [batch, x, y, num_channels]
      or [batch, x, y, z, num_channels].
    labels: a laid out Tensor with shape [batch, x, y, num_classes]
      or [batch, x, y, z, num_classes].

  Returns:
    A tuple (preds, loss, eval_metrics, all_bn_update_ops):
      preds: list of [downsampled liver prob, downsampled lesion prob,
        (downsampled lesion ground truth if FLAGS.output_ground_truth),
        per-case lesion intersect, per-case lesion area sum].
      loss: scalar mix of Dice and cross-entropy losses.
      eval_metrics: dict of scalar summary tensors.
      all_bn_update_ops: batch-norm update ops from every conv layer.
  """
  is_training = (dataset_str == 'train')
  if dataset_str == 'train':
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
  else:
    assert dataset_str == 'eval'
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
  # Block-count dims (nx, ny) and per-block size dims (sx, sy, sz).
  image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
  image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
  image_sx_dim = mtf.Dimension('image_sx_block',
                               FLAGS.ct_resolution // FLAGS.image_nx_block)
  image_sy_dim = mtf.Dimension('image_sy_block',
                               FLAGS.ct_resolution // FLAGS.image_ny_block)
  image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
  image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
  label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
  mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

  mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
  variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

  # Import input features.
  x = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(images),
                                 mtf_images_shape)
  x = mtf.cast(x, mtf_dtype)

  # Import ground truth labels.
  t = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(labels),
                                 mtf_labels_shape)
  t = mtf.cast(t, mtf_dtype)

  # Transpose the blocks.
  if FLAGS.sampled_2d_slices:
    x = mtf.transpose(x, [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
        image_c_dim
    ])

    t = mtf.transpose(t, [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
        label_c_dim
    ])
  else:
    x = mtf.transpose(x, [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
        image_sz_dim, image_c_dim
    ])

    t = mtf.transpose(t, [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
        image_sz_dim, label_c_dim
    ])

  # Network.
  levels = []
  all_bn_update_ops = []
  # add levels with convolution or down-sampling
  for depth in range(FLAGS.network_depth):
    for n_conv in range(FLAGS.n_conv_per_block):
      if depth == 0 and n_conv == 0:
        # no dropout in 1st layer.
        dropout_keep_p = 1.0
      else:
        dropout_keep_p = FLAGS.dropout_keep_p
      x, bn_update_ops = conv_with_spatial_partition(
          x, FLAGS.sampled_2d_slices,
          image_nx_dim, image_ny_dim,
          FLAGS.n_base_filters * (2**depth),
          dropout_keep_p,
          FLAGS.with_batch_norm,
          is_training,
          'conv_{}_{}'.format(depth, n_conv),
          variable_dtype,
          'conv_down_{}_{}'.format(depth, n_conv))
      all_bn_update_ops.extend(bn_update_ops)
    # Keep each depth's output for the skip connection on the way up.
    levels.append(x)

    if depth < FLAGS.network_depth - 1:
      if FLAGS.sampled_2d_slices:
        x = mtf.layers.max_pool2d(x, ksize=(2, 2))
      else:
        x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

  # add levels with up-convolution or up-sampling
  for depth in range(FLAGS.network_depth - 1)[::-1]:
    x = deconv_with_spatial_partition(
        x, FLAGS.sampled_2d_slices,
        image_nx_dim, image_ny_dim,
        FLAGS.n_base_filters * (2**depth),
        FLAGS.dropout_keep_p,
        'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
        variable_dtype,
        'deconv_{}_0'.format(depth))
    # Skip connection: concatenate with the matching down-path level along
    # its channel dim (named after that level's last conv).
    x = mtf.concat(
        [x, levels[depth]],
        concat_dim_name='conv_{}_{}'.format(depth,
                                            FLAGS.n_conv_per_block - 1))

    for n_conv in range(FLAGS.n_conv_per_block):
      x, bn_update_ops = conv_with_spatial_partition(
          x, FLAGS.sampled_2d_slices,
          image_nx_dim, image_ny_dim,
          FLAGS.n_base_filters * (2**depth),
          FLAGS.dropout_keep_p,
          FLAGS.with_batch_norm,
          is_training,
          'conv_{}_{}'.format(depth, n_conv),
          variable_dtype,
          'conv_up_{}_{}'.format(depth, n_conv))
      all_bn_update_ops.extend(bn_update_ops)

  # no dropout in the final layer.
  if FLAGS.sampled_2d_slices:
    y = mtf.layers.conv2d_with_blocks(
        x, mtf.Dimension('label_c', FLAGS.label_c),
        filter_size=(1, 1), strides=(1, 1), padding='SAME',
        h_blocks_dim=image_nx_dim, w_blocks_dim=image_ny_dim,
        variable_dtype=variable_dtype,
        name='final_conv_{}'.format(FLAGS.label_c),
    )
  else:
    y = mtf.layers.conv3d_with_blocks(
        x, mtf.Dimension('label_c', FLAGS.label_c),
        filter_size=(1, 1, 1), strides=(1, 1, 1), padding='SAME',
        d_blocks_dim=image_nx_dim, h_blocks_dim=image_ny_dim,
        variable_dtype=variable_dtype,
        name='final_conv_{}'.format(FLAGS.label_c),
    )

  # use mtf.constant to make sure there is no CPU-side constants.
  def scalar(v, dtype):
    return mtf.constant(mesh, v, shape=[], dtype=dtype)

  # Class 1 is treated as liver and class 2 as lesion.
  argmax_t = mtf.argmax(t, label_c_dim)
  liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
  lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)
  argmax_y = mtf.argmax(y, label_c_dim)
  lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

  # summary of class ratios.
  lesion_pred_ratio = mtf.reduce_mean(lesion_y)
  lesion_label_ratio = mtf.reduce_mean(lesion_t)

  # summary of accuracy.
  accuracy = mtf.reduce_mean(
      mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

  # Cross-entropy loss. Up-weight the liver region.
  # Liver pixels get weight xen_liver_weight. Lesion pixels (liver_t == 0
  # there, since argmax is exclusive) get 1 + xen_lesion_weight -
  # xen_liver_weight. NOTE(review): confirm this lesion weight is intended
  # rather than xen_lesion_weight.
  pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
      y, t, label_c_dim)
  pixel_weight = scalar(1, mtf_dtype) + \
      liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \
      lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight,
                        mtf_dtype)
  loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

  # Dice loss (negated soft Dice on the lesion channel, so lower is better).
  y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
  lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                               reduced_dim=mtf.Dimension('label_c', 1))
  prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                  output_shape=mtf.Shape([batch_dim]))
  prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                 output_shape=mtf.Shape([batch_dim]))
  loss_dice_per_case = mtf.reduce_mean(
      scalar(-2, mtf_dtype) * prob_intersect /
      (prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
  loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(
      prob_intersect) / (mtf.reduce_sum(prob_area_sum) +
                         scalar(FLAGS.dice_epsilon, mtf_dtype))
  loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(
      0.5, mtf_dtype)

  loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
      1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

  # Hard (argmax-based) lesion overlap, per case.
  intersect = mtf.reduce_sum(lesion_y * lesion_t,
                             output_shape=mtf.Shape([batch_dim]))
  area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                            output_shape=mtf.Shape([batch_dim]))
  # summary of dice.
  dice_per_case = mtf.reduce_mean(
      scalar(2, mtf_dtype) * intersect /
      (area_sum + scalar(0.000001, mtf_dtype)))
  dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
      mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

  eval_metrics = {
      'lesion_pred_ratio': lesion_pred_ratio,
      'lesion_label_ratio': lesion_label_ratio,
      'accuracy_of_all_classes': accuracy,
      'lesion_dice_per_case': dice_per_case,
      'lesion_dice_global': dice_global,
      'loss_xen': loss_xen,
      'loss_dice': loss_dice,
      'loss_dice_per_case': loss_dice_per_case,
      'loss_dice_global': loss_dice_global,
  }

  # Average-pool probabilities (and optionally the lesion ground truth) by
  # FLAGS.pred_downsample along each spatial axis for the prediction output.
  if FLAGS.sampled_2d_slices:
    y_prob_downsampled = mtf.layers.avg_pool2d(
        y_prob, ksize=(FLAGS.pred_downsample, ) * 2)
    if FLAGS.output_ground_truth:
      lesion_gt_downsampled = mtf.layers.avg_pool2d(
          mtf.slice(t, 2, 1, 'label_c'),
          ksize=(FLAGS.pred_downsample, ) * 2)
  else:
    y_prob_downsampled = mtf.layers.avg_pool3d(
        y_prob, ksize=(FLAGS.pred_downsample, ) * 3)
    if FLAGS.output_ground_truth:
      lesion_gt_downsampled = mtf.layers.avg_pool3d(
          mtf.slice(t, 2, 1, 'label_c'),
          ksize=(FLAGS.pred_downsample, ) * 3)

  liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
  lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
  # reduce_sum over the size-1 label_c slice just drops that dimension.
  preds = [
      mtf.reduce_sum(liver_prob_downsampled,
                     reduced_dim=mtf.Dimension('label_c', 1)),
      mtf.reduce_sum(lesion_prob_downsampled,
                     reduced_dim=mtf.Dimension('label_c', 1))
  ]

  if FLAGS.output_ground_truth:
    preds.append(
        mtf.reduce_sum(lesion_gt_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1)))

  preds.extend([intersect, area_sum])

  return preds, loss, eval_metrics, all_bn_update_ops
def _tile(x, n, tile_dim):
  """Tiles `x` n times along the dimension named `tile_dim.name`.

  Implemented as a concatenation of n references to `x`. `x` must already
  contain a dimension with that name; in the result its size is scaled by n.
  """
  copies = [x for _ in range(n)]
  return mtf.concat(copies, tile_dim.name)
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
  """Implements GBSWT from Charformer.

  For each candidate block width (2, 3, 4, plus 1 if requested), `x` is
  mean-pooled into non-overlapping blocks along `length_dim`. Each block is
  scored by a dense layer, and the blocks and their scores are repeated back
  to character resolution. The candidates are then mixed per position using
  the (optionally softmax- or tanh-normalized) scores. The result may be
  mean-downsampled along the length dimension.

  Args:
    x: a Tensor containing length_dim; its last dimension is used as the
      model dimension.
    length_dim: a Dimension
    max_subword_length: integer; currently unused (widths are fixed).
    downsample: optional integer pooling factor applied to the output after
      right-padding the length to a multiple of it.
    use_offsets: boolean; also score every shifted block alignment.
    consider_chars_as_blocks: boolean; adds width-1 blocks.
    use_block_pos_embedding: boolean; adds sinusoid position-within-block
      embeddings before pooling.
    share_block_kernel: boolean; one scoring kernel shared by all widths.
    memory_embeddings: integer; currently unused.
    context: Context; read when share_block_kernel or
      use_block_pos_embedding is set.
    block_mixing_mode: Str for block mixing ("score_attention" or None).
    activation: Str for block ranking: "softmax" or "tanh"; other values
      leave the scores unnormalized. NOTE(review): "sigtanh" gives a size-2
      score dim that does not line up with the block dim when mixing.
    downsample_function: Str, supports mean for now.

  Returns:
    a Tensor with the dims of x; the length dimension is shortened when
    `downsample` is set.

  Raises:
    ValueError: if `downsample` is set and `downsample_function` is not
      "mean".
  """
  # don't use this for now.
  del max_subword_length
  del memory_embeddings
  all_blocks = []
  all_scores = []
  tf.logging.info("GSW block layer")

  def _tile(x, n, tile_dim):
    # Simple tile function in MTF: [a, b] -> [a, b, a, b, ...].
    return mtf.concat([x] * n, tile_dim.name)

  def _repeat(x, n, repeat_dim):
    # Repeat each element n times along repeat_dim: [a, b] -> [a, a, b, b].
    # The size-1 "tmp" dim is inserted directly after repeat_dim so that the
    # final row-major reshape merges exactly (repeat_dim, tmp). The original
    # appended "tmp" at the end, which interleaved the copies with any
    # trailing dims (e.g. the model dim) and scrambled them.
    tmp_dim = mtf.Dimension("tmp", 1)
    expand_dims = []
    for dim in x.shape.dims:
      expand_dims.append(dim)
      if dim.name == repeat_dim.name:
        expand_dims.append(tmp_dim)
    x = mtf.reshape(x, mtf.Shape(expand_dims))
    x = _tile(x, n, tmp_dim)
    output_shape = []
    for dim in x.shape.dims:
      if dim.name == "tmp":
        continue
      if dim.name == repeat_dim.name:
        dim = mtf.Dimension(dim.name, dim.size * n)
      output_shape.append(dim)
    return mtf.reshape(x, mtf.Shape(output_shape))

  def _combined_dim(dims):
    return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

  # compute all subword blocks
  # TODO(yitay): handle offsets to get all blocks
  if activation == "sigtanh":
    # one score for sigmoid
    tmp_dim = mtf.Dimension("block_score", 2)
  else:
    tmp_dim = mtf.Dimension("block_score", 1)

  model_dim = x.shape[-1]
  subword_blocks_width = [2, 3, 4]

  if consider_chars_as_blocks:
    subword_blocks_width += [1]

  if share_block_kernel:
    block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
    block_kernel = mtf.get_variable(
        x.mesh,
        "block_kernel",
        block_kernel_shape,
        initializer=None,
        dtype=context.variable_dtype)
  else:
    block_kernel = None

  for subword_len in subword_blocks_width:
    if use_block_pos_embedding:
      # this is turn off by default. It is meant to support cases like
      # parameterized pooling or other features.
      block_len_dim = mtf.Dimension(length_dim.name, subword_len)
      # TODO(vqtran): Consider other positional embeddings.
      block_pos_emb = sinusoid_positional_embedding_weights(
          context.mesh, block_len_dim, x.shape[-1],
          context.variable_dtype.activation_dtype)
      # Tile (not repeat) so every block sees positions 0..subword_len-1.
      block_pos_emb = _tile(
          block_pos_emb, math.ceil(length_dim.size / float(subword_len)),
          block_len_dim)

    if use_offsets:
      offset_space = subword_len
    else:
      offset_space = 1

    for offsets in range(offset_space):
      if offsets > 0:
        xoff = mtf.shift(x, offsets, length_dim, wrap=False)
      else:
        xoff = x
      if use_block_pos_embedding:
        # Shift a fresh copy each iteration. The original re-assigned
        # block_pos_emb, so the shifts accumulated across offsets.
        if offsets > 0:
          pos_emb = mtf.shift(block_pos_emb, offsets,
                              block_pos_emb.shape[-2], wrap=False)
        else:
          pos_emb = block_pos_emb
      tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
      if length_dim.size % subword_len != 0:
        tf.logging.info("Not divisible by length")
        # add extra padding tokens
        pad_amt = int(subword_len) - int(length_dim.size % subword_len)
        kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
      else:
        kp = xoff

      if use_block_pos_embedding:
        kp += pos_emb

      bx = mtf.pool_tensor_1d(
          kp,
          pool_dim=kp.shape.get_dim_by_name(length_dim.name),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(subword_len))
      block_score = mtf.layers.dense(
          bx, [tmp_dim],
          use_bias=False,
          name="bx",
          reduced_dims=[model_dim],
          variable_dtype=None,
          kernel_weights=block_kernel)

      expand_bx = _repeat(bx, subword_len, length_dim)
      expand_scores = _repeat(block_score, subword_len, length_dim)
      if offsets > 0:
        # add offset.
        # NOTE(review): the input was already shifted right by `offsets`;
        # padding on the left shifts again. Confirm the intended alignment.
        expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [offsets, 0], length_dim.name)
      new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
      if new_len.size < length_dim.size:
        # Right-pad up to the input length (amount must be non-negative).
        pad_amt = length_dim.size - new_len.size
        expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [0, pad_amt], length_dim.name)
      elif new_len.size > length_dim.size:
        expand_bx = mtf.slice(expand_bx, 0, length_dim.size, length_dim.name)
        expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                  length_dim.name)

      # Add a size-1 "extra_dim" so all candidates can be stacked along it.
      new_tmp_dim = mtf.Dimension("extra_dim", 1)
      expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
      expand_scores_shape = mtf.Shape(expand_scores.shape.dims + [new_tmp_dim])
      expand_bx = mtf.reshape(expand_bx, expand_shape)
      expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
      all_blocks.append(expand_bx)
      all_scores.append(expand_scores)

  all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
  all_scores = mtf.concat(all_scores, new_tmp_dim.name)
  tf.logging.info(all_blocks)

  new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
  combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
  block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
  block_net = mtf.reshape(all_scores, block_net_shape)

  if block_mixing_mode == "score_attention":
    tf.logging.info("Using score attention")
    att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
    tf.logging.info(block_net)
    att = mtf.softmax(att, reduced_dim=att.shape[-1])
    block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
    tf.logging.info(block_net)

  if activation == "softmax":
    block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
  elif activation == "tanh":
    tf.logging.info("Using tanh")
    block_net = mtf.tanh(block_net)

  # Score-weighted sum over the candidate blocks.
  all_blocks = block_net * all_blocks
  all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
  output = all_blocks

  if downsample:
    output_length = output.shape.get_dim_by_name(length_dim.name)
    if output_length.size % int(downsample) != 0:
      pad_amt = int(downsample) - int(output_length.size % int(downsample))
      output = mtf.pad(output, [0, pad_amt], output_length.name)
    if downsample_function == "mean":
      output = mtf.pool_tensor_1d(
          output,
          pool_dim=output.shape.get_dim_by_name(length_dim.name),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(downsample))
    else:
      raise ValueError("Downsampling function not implemeneted.")

  return output
def ffn_layer_multi_inputs(self,
                           context,
                           mask,
                           inputs_list,
                           ffn_layer_type="dense",
                           kernel_initializer=None,
                           activation=None,
                           preprocess=False,
                           postprocess=False):
  """Implements a feed-forward layer over one or more inputs.

  Multiple inputs are concatenated along the model dimension before the
  feed-forward layer is applied.

  Args:
    context: mtf context
    mask: optional mask tensor multiplied into each input (when preprocess)
      and into the output (when postprocess); None disables masking.
    inputs_list: non-empty sequence of input tensors; it is not modified.
    ffn_layer_type: "dense" or "dense_relu_dense".
    kernel_initializer: kernel initializer (dense only).
    activation: activation function (dense only).
    preprocess: if True, layer-norm each (masked) input independently.
    postprocess: if True, layer-norm the (masked) output.

  Returns:
    a tensor

  Raises:
    ValueError: if inputs_list is empty or ffn_layer_type is unknown.
  """
  # need at least one inputs
  if not inputs_list:
    raise ValueError("ffn_layer_multi_inputs needs at least one input.")

  if preprocess:
    # In case of having more than one input to the ffn, we just apply layer
    # norm on them independently as preprocessing. Build a new list rather
    # than overwriting the caller's list in place.
    inputs_list = [
        self._layer_norm(context,
                         (inputs * mask) if mask is not None else inputs)
        for inputs in inputs_list
    ]

  # the output size is the hidden size of the main inputs
  if len(inputs_list) == 1:
    ffn_inputs = inputs_list[0]
  else:
    ffn_inputs = mtf.concat(inputs_list, context.model.model_dim.name)

  if ffn_layer_type == "dense":
    output = mtf.layers.dense(
        ffn_inputs,
        reduced_dims=[ffn_inputs.shape.dims[-1]],
        new_dims=[context.model.model_dim],
        activation=activation,
        use_bias=True,
        variable_dtype=context.variable_dtype,
        expert_dims=context.model.ensemble_dims,
        kernel_initializer=kernel_initializer)
  elif ffn_layer_type == "dense_relu_dense":
    output = mtf.layers.dense_relu_dense(
        ffn_inputs,
        hidden_channels=context.model.model_dim,
        is_training=context.train,
        dropout=self.relu_dropout)
  else:
    raise ValueError("Unknown ffn_layer type: %s" % ffn_layer_type)

  if postprocess:
    output = self._layer_norm(
        context, (output * mask) if mask is not None else output)

  return output
def get_timing_signal_1d(self,
                         context,
                         length,
                         channels,
                         min_timescale=1.0,
                         max_timescale=1.0e4,
                         start_index=0):
  """Gets a bunch of sinusoids of different frequencies.

  Each channel of the input Tensor is incremented by a sinusoid of a different
  frequency and phase. This allows attention to learn to use absolute and
  relative positions. Timing signals should be added to some precursors of
  both the query and the memory inputs to attention.

  The use of relative position is possible because sin(x+y) and cos(x+y) can
  be expressed in terms of y, sin(x) and cos(x).

  In particular, we use a geometric sequence of timescales starting with
  min_timescale and ending with max_timescale. The number of different
  timescales is equal to channels / 2. For each timescale, we generate the
  two sinusoidal signals sin(timestep/timescale) and cos(timestep/timescale).
  All of these sinusoids are concatenated in the channels dimension.

  Args:
    context: mtf context.
    length: a mtf.Dimension, length of timing signal sequence.
    channels: a mtf.Dimension, size of timing embeddings to create. The
      number of different timescales is equal to channels / 2.
    min_timescale: a float
    max_timescale: a float
    start_index: index of first position (int or scalar int tensor).

  Returns:
    a Tensor of timing signals with dims [expanded=1, length, channels].

  Raises:
    NotImplementedError: if channels.size is odd.
  """
  # Check before building any ops.
  if channels.size % 2 != 0:
    raise NotImplementedError("Odd channel size not implemented.")

  # NOTE(review): positions are presumably integer (e.g. an int32 range);
  # cast them so they can be multiplied with the float inv_timescales.
  position = mtf.cast(context.get_position() + start_index,
                      context.activation_dtype)
  num_timescales = mtf.constant(context.mesh, channels.size // 2)
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      mtf.maximum(num_timescales - 1, 1))
  channel_dim_name = channels.name
  inv_timescales = (
      min_timescale * mtf.exp(
          mtf.mtf_range(context.mesh,
                        mtf.Dimension(channel_dim_name, channels.size // 2),
                        context.activation_dtype) * -log_timescale_increment))

  scaled_time = position * inv_timescales
  # Please note that this slightly differs from the published paper.
  # See a discussion here:
  # https://github.com/tensorflow/tensor2tensor/pull/177
  concat_dim_name = channels.name
  signal = mtf.concat(
      [mtf.sin(scaled_time), mtf.cos(scaled_time)],
      concat_dim_name=concat_dim_name)

  # `length` and `channels` are mtf.Dimension objects; the original
  # `length.shape.dims` / `channels.shape.dim` raised AttributeError.
  # NOTE(review): the reshape assumes the position has only `length` as a
  # dim (true for a default range); a [batch, length] position would not fit.
  new_dims = [mtf.Dimension("expanded", 1), length, channels]
  signal = mtf.reshape(signal, mtf.Shape(new_dims))
  return signal
def _call_internal(self, context, inputs, targets=None, attributes=None,
                   z=None):
  """Compute logits based on inputs (all positions in parallel).

  Also updates context if applicable (a loss is appended to
  `context.losses` when targets are given).

  Args:
    context: a Context
    inputs: a Tensor
    targets: an optional Tensor
    attributes: an optional Tensor of attribute ids; required when
      `self.attribute_embedding` is set.
    z: an optional Tensor, projected to model_dim and added to the
      embeddings.

  Returns:
    logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim],
      or the layer-stack output when `self.output_vocab_dim` is None.

  Raises:
    ValueError: if the positional embedding table is too short for the
      default positions, or attributes are missing while
      `self.attribute_embedding` is set.
  """
  mesh = inputs.mesh
  if self.ensemble_dim and self.ensemble_dim not in inputs.shape.dims:
    # Training an ensemble where all models are trained on the same examples.
    inputs = mtf.broadcast(inputs, [self.ensemble_dim] + inputs.shape.dims)
    # Guard on None: attributes/targets are optional arguments.
    if (attributes is not None and
        self.ensemble_dim not in attributes.shape.dims):
      attributes = mtf.broadcast(attributes,
                                 [self.ensemble_dim] + attributes.shape.dims)
    if targets is not None:
      targets = mtf.broadcast(targets,
                              [self.ensemble_dim] + targets.shape.dims)

  if "embedding" in context.shared_params:
    vocab_embedding = context.shared_params["embedding"]
  else:
    vocab_embedding = VocabEmbedding(
        mesh,
        self.input_vocab_dim,
        self.model_dim,
        context.variable_dtype,
        name="embedding",
        ensemble_dim=self.ensemble_dim)
  x = vocab_embedding.ids_to_embedding(inputs)

  if self.positional_embedding:
    if "positional_embedding" in context.shared_params:
      pos_emb_var = context.shared_params["positional_embedding"]
    else:
      pos_emb_var = mtf.layers.embedding_weights(
          mesh,
          self.max_length_dim,
          self.model_dim,
          context.variable_dtype,
          "positional_embedding",
          ensemble_dim=self.ensemble_dim)
    if (context.length_dim is not None and
        context.length_dim.size > self.max_length_dim.size):
      message = (
          "Length dimenison exceeds size of positional embedding table. "
          "length_dim.size > max_length_dim.size %s vs %s." %
          (context.length_dim, self.max_length_dim))
      if context.position_is_default:
        # Definitely getting overflow in this case.
        raise ValueError(message)
      else:
        tf.logging.warning(
            message +
            " This may be OK if there are several shorter sequences packed "
            "together. Otherwise, the later positions will get zeros.")
    if context.position_is_default:
      pos_emb = mtf.rename_dimension(
          mtf.slice(pos_emb_var, 0, context.length_dim.size,
                    self.max_length_dim.name), self.max_length_dim.name,
          context.length_dim.name)
    else:
      pos_emb = mtf.gather(
          pos_emb_var,
          context.position,
          self.max_length_dim,
          output_shape=x.shape)
    x += pos_emb

  if self.attribute_embedding:
    if attributes is None:
      raise ValueError(
          "attributes must be provided when attribute_embedding is set.")
    if "attribute_embedding" in context.shared_params:
      att_emb_var = context.shared_params["attribute_embedding"]
    else:
      att_emb_var = mtf.layers.embedding_weights(
          mesh,
          self.attribute_dim,
          self.model_dim,
          context.variable_dtype,
          "attribute_embedding",
          ensemble_dim=self.ensemble_dim)

    att_emb = mtf.gather(
        att_emb_var, attributes, self.attribute_dim, output_shape=x.shape)
    # Combine token and attribute embeddings: concatenate along the model
    # dim, then project back to model_dim with a dense layer.
    x_attribute = mtf.concat([x, att_emb], self.model_dim.name)
    x = mtf.layers.dense(
        x_attribute,
        self.model_dim,
        activation=None,
        variable_dtype=context.variable_dtype,
        name="comb_x_attribute")

  if z is not None:
    z = mtf.layers.dense(
        z,
        self.model_dim,
        activation=None,
        variable_dtype=context.variable_dtype,
        name="z")
    x += z

  x = self.layer_stack.call(context, x)
  if self.output_vocab_dim is None:
    return x
  if self.shared_embedding_and_softmax_weights:
    logits = vocab_embedding.hidden_to_logits(x)
  else:
    logits = mtf.layers.dense(
        x,
        self.output_vocab_dim,
        use_bias=False,
        variable_dtype=context.variable_dtype,
        reduced_dims=x.shape.dims[-1:],
        name="logits")
  if targets is not None and context.losses is not None:
    context.losses.append(
        self._compute_loss(context, logits, targets, self.output_vocab_dim))
  if self.ensemble_dim:
    logits = reduce_ensemble_logits(logits, self.ensemble_dim,
                                    self.output_vocab_dim)
  return logits