Example #1
 def _sigmoid_tree(self, tensor):
   """Create probability distribution along gates dim using a sigmoid tree."""
   gamma = mtf.split(
       mtf.sigmoid(tensor), self._pre_gates_dim, self._pre_gates_dim.size)
   return mtf.concat([
       gamma[0] * gamma[1],
       gamma[0] * (1 - gamma[1]),
       (1 - gamma[0]) * gamma[2],
       (1 - gamma[0]) * (1 - gamma[2]),
   ], self._gates_dim.name)
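The four products above form a probability distribution over the gates dimension: the two branches under gamma[0] sum to gamma[0], the other two sum to 1 - gamma[0], so the total is 1. A quick NumPy check of that identity (purely illustrative; not part of the example's codebase):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Three independent sigmoid gates, as produced by the mtf.split above.
g0, g1, g2 = sigmoid(np.random.randn(3))

branches = [
    g0 * g1,
    g0 * (1 - g1),
    (1 - g0) * g2,
    (1 - g0) * (1 - g2),
]
print(sum(branches))  # 1.0 (up to floating-point error)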
Example #2
 def hidden_to_logits(self, hidden, context):
   # Each cluster returns the logits only for its own tokens, so their
   # concatenation gives the full logits.
   return mtf.concat(
       [
           cluster.hidden_to_logits(hidden, context=context)
           for cluster in self._clusters
       ],
       concat_dim_name=self._vocab_dim.name,
   )
Example #3
def Concat(tsr_lst, name=None):
    assert all(tsr_lst[0].shape[:-1] == t.shape[:-1] for t in tsr_lst[1:])

    concat_dim_name = utils.RandName()
    concat_tsrs = []
    for t in tsr_lst:
        assert not t.shape[-1].name.startswith('axis')
        t = mt.rename_dimension(t, t.shape[-1].name, concat_dim_name)
        concat_tsrs.append(t)

    return mtf.concat(concat_tsrs, concat_dim_name, name)
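For reference, a minimal self-contained sketch of what the wrapper above does at the shape level, assuming only the public mesh_tensorflow API (the mesh and dimension names here are made up for illustration): both inputs have their last dimension renamed to a shared name, and mtf.concat joins them along it.

import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")

batch = mtf.Dimension("batch", 2)
a = mtf.zeros(mesh, mtf.Shape([batch, mtf.Dimension("feat_a", 3)]))
b = mtf.zeros(mesh, mtf.Shape([batch, mtf.Dimension("feat_b", 5)]))

# Rename both trailing dimensions to one shared name so they can be concatenated.
a = mtf.rename_dimension(a, "feat_a", "feat")
b = mtf.rename_dimension(b, "feat_b", "feat")

c = mtf.concat([a, b], concat_dim_name="feat")
print(c.shape)  # e.g. Shape[batch=2, feat=8]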
Example #4
    def hidden_to_logits(self, hidden: mtf.Tensor,
                         context: transformer.Context) -> mtf.Tensor:
        """Function called by mtf transformer to get the logits.

    Args:
      hidden: an mtf.Tensor, hidden model states of the final decoder layer.
      context: a transformer.Context, the context used for the call to the
        transformer.

    Returns:
      An mtf.Tensor, the logits.
    """
        hidden *= self._output_dim.size**-0.5

        component_contexts = mtf.einsum([
            mtf.rename_dimension(hidden, self._output_dim.name,
                                 self._copy_output_dim.name),
            self._context_weights,
        ],
                                        reduced_dims=[self._copy_output_dim])
        component_contexts = mtf.tanh(component_contexts +
                                      self._context_weights_bias)
        component_logits = mtf.einsum(
            [component_contexts, self._embedding_weights],
            reduced_dims=[self._output_dim])
        component_logits = self._dropout(component_logits, context)

        prior_tanh = mtf.tanh(
            mtf.einsum([self._prior_weights, hidden],
                       reduced_dims=[self._output_dim]) +
            self._prior_weights_bias)
        prior_tanh = self._dropout(prior_tanh, context)
        prior_shared_logits = mtf.einsum([self._prior_gates_vector, hidden],
                                         reduced_dims=[self._output_dim])
        prior_frequent_vocab_logits = (
            mtf.einsum([self._prior_vocab_vector, prior_tanh]) +
            prior_shared_logits + self._prior_bias)
        prior_logits = mtf.concat([
            prior_frequent_vocab_logits,
            mtf.ones(self._mesh,
                     mtf.Shape([self._rare_vocab_dim]),
                     dtype=prior_shared_logits.dtype) * prior_shared_logits
        ], self._vocab_dim.name)
        if context.train and self._noise_std_dev != 0.0:
            prior_logits += mtf.random_normal(self._mesh,
                                              prior_logits.shape,
                                              stddev=self._noise_std_dev)
        prior_proportions = self._sigmoid_tree(prior_logits)

        logits = mtf.einsum([component_logits, prior_proportions],
                            reduced_dims=[self._gates_dim])
        return self._rearrange_sentinels(logits)
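The final einsum above mixes the per-gate logits using the sigmoid-tree proportions: each vocabulary entry's logit is a weighted combination over the gates dimension. A NumPy analogue of just that reduction (toy sizes, names are illustrative):

import numpy as np

num_gates, vocab = 4, 10
component_logits = np.random.randn(num_gates, vocab)   # [gates, vocab]
proportions = np.random.dirichlet(np.ones(num_gates))  # sums to 1 over gates

# Equivalent of mtf.einsum([component_logits, prior_proportions],
#                          reduced_dims=[gates_dim]).
logits = np.einsum("gv,g->v", component_logits, proportions)
print(logits.shape)  # (10,)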
Example #5
    def add_position_timing_signal_func(self, context, x, step):
        """Add n-dimensional embedding as the position (horizontal) timing signal.

    Args:
      context: mtf context
      x: a tensor with shape [batch, length, depth]
      step: step

    Returns:
      a Tensor with the same shape as x.

    """

        if not self.position_start_index:
            index = 0

        elif self.position_start_index == "random":
            # Shift all positions randomly
            # TODO(dehghani): What would be a reasonable max number of shifts?
            index = mtf.random_uniform(context.mesh, [],
                                       maxval=x.shape.dims[1].size,
                                       dtype=tf.int32)

        elif self.position_start_index == "step":
            # Shift positions based on the step
            if self.recurrence_type == "act":
                num_steps = self.act_max_steps
            else:
                num_steps = self.num_rec_steps
            index = mtf.cast(x.shape.dims[1].size * step / num_steps,
                             dtype=tf.int32)

        length = context.length_dim
        channels = context.model.model_dim
        signal = self.get_timing_signal_1d(context,
                                           length,
                                           channels,
                                           start_index=index)

        if self.add_or_concat_timing_signal == "add":
            x_with_timing = x + mtf.cast(signal, x.dtype)
        # Unimplemented
        if self.add_or_concat_timing_signal == "concat":
            batch_dim = x.shape.dims[0]
            out_shape = mtf.Shape([batch_dim] + signal.shape.dims[1:])
            signal_tiled = mtf.broadcast(signal, out_shape)
            x_with_timing = mtf.concat(
                (x, signal_tiled),
                concat_dim_name=signal_tiled.dimension_names[-1])

        return x_with_timing
Example #6
def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype,
                      mesh):
    """memory / key values from all attention paper"""

    dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv)
    emb_dim = k.shape[-1]
    mem_std = 1 / math.sqrt(emb_dim.size)

    mem_k = mtf.get_variable(
        mesh,
        "mem_k",
        mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
        initializer=tf.random_normal_initializer(stddev=mem_std),
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        activation_dtype=variable_dtype.activation_dtype,
    )
    mem_v = mtf.get_variable(
        mesh,
        "mem_v",
        mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
        initializer=tf.random_normal_initializer(stddev=mem_std),
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        activation_dtype=variable_dtype.activation_dtype)

    mem_k, mem_v = map(
        lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]
                                ), (mem_k, mem_v))
    mem_k, mem_v = map(
        lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"),
        (mem_k, mem_v))

    k = mtf.concat([mem_k, k], "sequence")
    v = mtf.concat([mem_v, v], "sequence")
    return k, v
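The net effect of the block above is to prepend a small set of learned "persistent" key/value rows to every sequence in the batch before attention. A NumPy analogue of the broadcast-then-concat step (shapes are illustrative assumptions):

import numpy as np

batch, seq, heads, emb = 2, 16, 4, 8
num_mem_kv = 3

k = np.random.randn(batch, seq, heads, emb)         # ordinary keys
mem_k = np.random.randn(num_mem_kv, heads, emb)     # learned memory keys

# Broadcast the memory keys across the batch, then concatenate along the
# sequence axis -- the same role mtf.broadcast + mtf.concat play above.
mem_k_b = np.broadcast_to(mem_k, (batch, num_mem_kv, heads, emb))
k = np.concatenate([mem_k_b, k], axis=1)
print(k.shape)  # (2, 19, 4, 8)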
Example #7
  def _hidden_to_logits(self, hidden, context):
    """Actually compute the logits over the entire vocab."""
    head_size = self._head_cluster.end_token_id
    # Note that computing the log softmax is equivalent to computing the logits.
    head_log_softmax = self._head_cluster.compute_log_softmax(hidden, context)
    logits = [
        self._head_cluster.get_log_softmax_prefix(head_log_softmax, head_size)
    ]

    for i, cluster in enumerate(self._tail_clusters):
      tail_log_softmax = cluster.compute_log_softmax(hidden, context)
      cluster_softmax = self._head_cluster.get_log_softmax_value(
          head_log_softmax, head_size + i)
      logits.append(cluster_softmax + tail_log_softmax)
    return mtf.concat(logits, concat_dim_name=self._vocab_dim.name)
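The concatenation above is the usual adaptive-softmax factorization: for a tail token, log p(token) = log p(its cluster) + log p(token | cluster), so the concatenated pieces still form one normalized log-distribution over the full vocabulary. A small NumPy check of that identity (toy sizes, not the example's API):

import numpy as np

def log_softmax(x):
    x = x - x.max()
    return x - np.log(np.exp(x).sum())

head_size, num_tail, tail_vocab = 5, 2, 7
# Head distribution over the frequent tokens plus one entry per tail cluster.
head_ls = log_softmax(np.random.randn(head_size + num_tail))

pieces = [head_ls[:head_size]]
for i in range(num_tail):
    tail_ls = log_softmax(np.random.randn(tail_vocab))
    pieces.append(head_ls[head_size + i] + tail_ls)

full = np.concatenate(pieces)
print(np.exp(full).sum())  # 1.0 (up to floating-point error)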
Example #8
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
    """Builds the UNet model graph, train op and eval metrics.

  Args:
    mesh: a MeshTensorflow.mesh object.
    mesh_impl: a mesh implementation, such as SimdMeshImpl and
      PlacementMeshImpl.
    dataset_str: a string, either 'train' or 'eval'. This is used for batch_norm.
    images: a laid out Tensor with shape [batch, x, y, num_channels]
      or [batch, x, y, z, num_channels].
    labels: a laid out Tensor with shape [batch, x, y, num_classes]
      or [batch, x, y, z, num_classes].

  Returns:
    Prediction and loss.
  """

    is_training = (dataset_str == 'train')
    if dataset_str == 'train':
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
    else:
        assert dataset_str == 'eval'
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
    image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
    image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
    image_sx_dim = mtf.Dimension('image_sx_block',
                                 FLAGS.ct_resolution // FLAGS.image_nx_block)
    image_sy_dim = mtf.Dimension('image_sy_block',
                                 FLAGS.ct_resolution // FLAGS.image_ny_block)
    image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
    image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
    label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
    mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

    mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
    variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

    # Import input features.
    x = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(images),
                                   mtf_images_shape)
    x = mtf.cast(x, mtf_dtype)

    # Import ground truth labels.
    t = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(labels),
                                   mtf_labels_shape)
    t = mtf.cast(t, mtf_dtype)

    # Transpose the blocks.
    if FLAGS.sampled_2d_slices:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            label_c_dim
        ])
    else:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, label_c_dim
        ])

    # Network.
    levels = []
    all_bn_update_ops = []
    # add levels with convolution or down-sampling
    for depth in range(FLAGS.network_depth):
        for n_conv in range(FLAGS.n_conv_per_block):
            if depth == 0 and n_conv == 0:
                # no dropout in 1st layer.
                dropout_keep_p = 1.0
            else:
                dropout_keep_p = FLAGS.dropout_keep_p
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_down_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)
        levels.append(x)

        if depth < FLAGS.network_depth - 1:
            if FLAGS.sampled_2d_slices:
                x = mtf.layers.max_pool2d(x, ksize=(2, 2))
            else:
                x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

    # add levels with up-convolution or up-sampling
    for depth in range(FLAGS.network_depth - 1)[::-1]:
        x = deconv_with_spatial_partition(
            x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
            FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
            'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
            variable_dtype, 'deconv_{}_0'.format(depth))
        x = mtf.concat([x, levels[depth]],
                       concat_dim_name='conv_{}_{}'.format(
                           depth, FLAGS.n_conv_per_block - 1))

        for n_conv in range(FLAGS.n_conv_per_block):
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_up_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)

    # no dropout in the final layer.
    if FLAGS.sampled_2d_slices:
        y = mtf.layers.conv2d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            h_blocks_dim=image_nx_dim,
            w_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )
    else:
        y = mtf.layers.conv3d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1, 1),
            strides=(1, 1, 1),
            padding='SAME',
            d_blocks_dim=image_nx_dim,
            h_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )

    # Use mtf.constant to make sure there are no CPU-side constants.
    def scalar(v, dtype):
        return mtf.constant(mesh, v, shape=[], dtype=dtype)

    argmax_t = mtf.argmax(t, label_c_dim)
    liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
    lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)

    argmax_y = mtf.argmax(y, label_c_dim)
    lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

    # summary of class ratios.
    lesion_pred_ratio = mtf.reduce_mean(lesion_y)
    lesion_label_ratio = mtf.reduce_mean(lesion_t)

    # summary of accuracy.
    accuracy = mtf.reduce_mean(
        mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

    # Cross-entropy loss. Up-weight the liver region.
    pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
        y, t, label_c_dim)
    pixel_weight = scalar(1, mtf_dtype) + \
        liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \
        lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight,
                          mtf_dtype)
    loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

    # Dice loss
    y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
    lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                                 reduced_dim=mtf.Dimension('label_c', 1))
    prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                    output_shape=mtf.Shape([batch_dim]))
    prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                   output_shape=mtf.Shape([batch_dim]))
    loss_dice_per_case = mtf.reduce_mean(
        scalar(-2, mtf_dtype) * prob_intersect /
        (prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
    loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(
        prob_intersect) / (mtf.reduce_sum(prob_area_sum) +
                           scalar(FLAGS.dice_epsilon, mtf_dtype))

    loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(
        0.5, mtf_dtype)

    loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
        1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

    intersect = mtf.reduce_sum(lesion_y * lesion_t,
                               output_shape=mtf.Shape([batch_dim]))
    area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                              output_shape=mtf.Shape([batch_dim]))
    # summary of dice.
    dice_per_case = mtf.reduce_mean(
        scalar(2, mtf_dtype) * intersect /
        (area_sum + scalar(0.000001, mtf_dtype)))
    dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
        mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

    eval_metrics = {
        'lesion_pred_ratio': lesion_pred_ratio,
        'lesion_label_ratio': lesion_label_ratio,
        'accuracy_of_all_classes': accuracy,
        'lesion_dice_per_case': dice_per_case,
        'lesion_dice_global': dice_global,
        'loss_xen': loss_xen,
        'loss_dice': loss_dice,
        'loss_dice_per_case': loss_dice_per_case,
        'loss_dice_global': loss_dice_global,
    }

    if FLAGS.sampled_2d_slices:
        y_prob_downsampled = mtf.layers.avg_pool2d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 2)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool2d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 2)
    else:
        y_prob_downsampled = mtf.layers.avg_pool3d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 3)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool3d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 3)

    liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
    lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
    preds = [
        mtf.reduce_sum(liver_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1)),
        mtf.reduce_sum(lesion_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1))
    ]

    if FLAGS.output_ground_truth:
        preds.append(
            mtf.reduce_sum(lesion_gt_downsampled,
                           reduced_dim=mtf.Dimension('label_c', 1)))

    preds.extend([intersect, area_sum])

    return preds, loss, eval_metrics, all_bn_update_ops
Example #9
 def _tile(x, n, tile_dim):
     # Simple tile function in MTF.
     return mtf.concat([x] * n, tile_dim.name)
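Tiling in Mesh TensorFlow is just a concat of n copies along the named dimension; the NumPy equivalent is np.concatenate (or np.tile) along the corresponding axis. Purely illustrative:

import numpy as np

x = np.arange(6).reshape(2, 3)
n = 4

# Same idea as mtf.concat([x] * n, tile_dim.name), with tile_dim on axis 1.
tiled = np.concatenate([x] * n, axis=1)
print(tiled.shape)                                 # (2, 12)
print(np.array_equal(tiled, np.tile(x, (1, n))))   # True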
Example #10
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
    """Implements GBSWT from Charformer.

  Args:
    x: a Tensor containing length_dim
    length_dim: a Dimension
    max_subword_length: integer
    downsample: integer.
    use_offsets: boolean.
    consider_chars_as_blocks: boolean.
    use_block_pos_embedding: boolean.
    share_block_kernel: boolean.
    memory_embeddings: integer.
    context: Context.
    block_mixing_mode: Str for block mixing.
    activation: Str for block ranking.
    downsample_function: Str, supports mean/linformer for now.

  Returns:
    a Tensor with the same shape as x.

  Raises:
    ValueError: if channels or depth don't match.
  """
    # don't use this for now.
    del max_subword_length
    del memory_embeddings
    all_blocks = []
    all_scores = []
    tf.logging.info("GSW block layer")

    def _tile(x, n, tile_dim):
        # Simple tile function in MTF.
        return mtf.concat([x] * n, tile_dim.name)

    def _repeat(x, n, repeat_dim):
        # repeat function in MTF
        tmp_dim = mtf.Dimension("tmp", 1)
        expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
        x = mtf.reshape(x, expand_shape)
        x = _tile(x, n, tmp_dim)
        output_shape = []
        for dim in x.shape.dims:
            if dim.name == "tmp":
                continue
            if dim.name == repeat_dim.name:
                dim = mtf.Dimension(dim.name, dim.size * n)
            output_shape.append(dim)
        output_shape = mtf.Shape(output_shape)
        x = mtf.reshape(x, output_shape)
        return x

    def _combined_dim(dims):
        return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

    # compute all subword blocks
    # TODO(yitay): handle offsets to get all blocks
    if activation == "sigtanh":
        # one score for sigmoid
        tmp_dim = mtf.Dimension("block_score", 2)
    else:
        tmp_dim = mtf.Dimension("block_score", 1)

    model_dim = x.shape[-1]
    subword_blocks_width = [2, 3, 4]

    if consider_chars_as_blocks:
        subword_blocks_width += [1]

    if share_block_kernel:
        block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
        block_kernel = mtf.get_variable(x.mesh,
                                        "block_kernel",
                                        block_kernel_shape,
                                        initializer=None,
                                        dtype=context.variable_dtype)
    else:
        block_kernel = None

    for subword_len in subword_blocks_width:
        if use_block_pos_embedding:
            # This is turned off by default. It is meant to support cases like
            # parameterized pooling or other features.
            block_len_dim = mtf.Dimension(length_dim.name, subword_len)
            # TODO(vqtran): Consider other positional embeddings.
            block_pos_emb = sinusoid_positional_embedding_weights(
                context.mesh, block_len_dim, x.shape[-1],
                context.variable_dtype.activation_dtype)
            block_pos_emb = _repeat(
                block_pos_emb, math.ceil(length_dim.size / float(subword_len)),
                block_len_dim)
        if use_offsets:
            offset_space = subword_len
        else:
            offset_space = 1
        for offsets in range(offset_space):
            if offsets > 0:
                xoff = mtf.shift(x, offsets, length_dim, wrap=False)
                if use_block_pos_embedding:
                    block_pos_emb = mtf.shift(block_pos_emb,
                                              offsets,
                                              block_pos_emb.shape[-2],
                                              wrap=False)
            else:
                xoff = x
            tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
            if length_dim.size % subword_len != 0:
                tf.logging.info("Not divisible by length")
                # add extra padding tokens
                pad_amt = int(subword_len) - int(length_dim.size % subword_len)
                kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
            else:
                kp = xoff

            if use_block_pos_embedding:
                kp += block_pos_emb

            bx = mtf.pool_tensor_1d(
                kp,
                pool_dim=kp.shape.get_dim_by_name("length"),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(subword_len))
            block_score = mtf.layers.dense(bx, [tmp_dim],
                                           use_bias=False,
                                           name="bx",
                                           reduced_dims=[model_dim],
                                           variable_dtype=None,
                                           kernel_weights=block_kernel)

            expand_bx = _repeat(bx, subword_len, length_dim)
            expand_scores = _repeat(block_score, subword_len, length_dim)
            if offsets > 0:
                # add offset.
                expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [offsets, 0],
                                        length_dim.name)
            new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
            if new_len.size < length_dim.size:
                pad_amt = new_len.size - length_dim.size
                expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [0, pad_amt],
                                        length_dim.name)
            elif new_len.size > length_dim.size:
                expand_bx = mtf.slice(expand_bx, 0, length_dim.size,
                                      length_dim.name)
                expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                          length_dim.name)

            new_tmp_dim = mtf.Dimension("extra_dim", 1)
            expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
            expand_scores_shape = mtf.Shape(expand_scores.shape.dims +
                                            [new_tmp_dim])
            expand_bx = mtf.reshape(expand_bx, expand_shape)
            expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
            all_blocks.append(expand_bx)
            all_scores.append(expand_scores)

    all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
    all_scores = mtf.concat(all_scores, new_tmp_dim.name)
    tf.logging.info(all_blocks)
    new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
    combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
    block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
    block_net = mtf.reshape(all_scores, block_net_shape)

    if block_mixing_mode == "score_attention":
        tf.logging.info("Using score attention")
        att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
        tf.logging.info(block_net)
        att = mtf.softmax(att, reduced_dim=att.shape[-1])
        block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
        tf.logging.info(block_net)

    if activation == "softmax":
        block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
    elif activation == "tanh":
        tf.logging.info("Using tanh")
        block_net = mtf.tanh(block_net)

    all_blocks = block_net * all_blocks
    all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
    output = all_blocks

    if downsample:
        output_length = output.shape.get_dim_by_name("length")
        if output_length.size % int(downsample) != 0:
            pad_amt = int(downsample) - int(
                output_length.size % int(downsample))
            output = mtf.pad(output, [0, pad_amt], output_length.name)
        if downsample_function == "mean":
            output = mtf.pool_tensor_1d(
                output,
                pool_dim=output.shape.get_dim_by_name("length"),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(downsample))
        else:
            raise ValueError("Downsampling function not implemeneted.")

    return output
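The _repeat helper above emulates a repeat along length_dim by adding a temporary size-1 dimension, tiling it, and reshaping it away; the intended effect is NumPy's np.repeat along that axis. A small illustrative sketch of that effect (with the temporary axis placed next to the repeated one for clarity; toy sizes):

import numpy as np

x = np.arange(6).reshape(3, 2)  # axis 0 plays the role of length_dim
n = 2

expanded = x[:, None, :]                        # [length, tmp=1, model]
tiled = np.concatenate([expanded] * n, axis=1)  # [length, tmp=n, model]
repeated = tiled.reshape(3 * n, 2)              # [length * n, model]

print(np.array_equal(repeated, np.repeat(x, n, axis=0)))  # True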
Example #11
    def ffn_layer_multi_inputs(self,
                               context,
                               mask,
                               inputs_list,
                               ffn_layer_type="dense",
                               kernel_initializer=None,
                               activation=None,
                               preprocess=False,
                               postprocess=False):
        """Implements a Feed-forward layer with multiple inputs, pad-removing, etc.

    Args:
      context: mtf context
      mask: mask
      inputs_list: list of input tensors
      ffn_layer_type: dense / dense_dropconnect / dense_relu_dense
      kernel_initializer: kernel initializer
      activation: activation function
      preprocess: whether to preprocess the input (default: layer norm)
      postprocess: whether to postprocess the output (default: dropout and residual)

    Returns:
      a tensor
    Raises:
      ValueError: Unknown ffn_layer type.

    """

        # need at least one input
        num_inputs = len(inputs_list)
        assert num_inputs > 0

        if preprocess:
            # If there is more than one input to the ffn, we apply layer norm
            # to each of them independently as preprocessing.
            for i, inputs in enumerate(inputs_list):
                inputs_list[i] = self._layer_norm(
                    context, (inputs * mask) if mask else inputs)

        # the output size is the hidden size of the main inputs
        ffn_inputs = inputs_list[0]
        if len(inputs_list) != 1:
            ffn_inputs = mtf.concat(inputs_list, context.model.model_dim.name)
        if ffn_layer_type == "dense":
            # last_dims = [
            #     mtf.Dimension(ffn_inputs.shape.dims[-1].name, hidden_size)
            # ]
            output = mtf.layers.dense(ffn_inputs,
                                      reduced_dims=[ffn_inputs.shape.dims[-1]],
                                      new_dims=[context.model.model_dim],
                                      activation=activation,
                                      use_bias=True,
                                      variable_dtype=context.variable_dtype,
                                      expert_dims=context.model.ensemble_dims,
                                      kernel_initializer=kernel_initializer)
        elif ffn_layer_type == "dense_relu_dense":
            output = mtf.layers.dense_relu_dense(
                ffn_inputs,
                hidden_channels=context.model.model_dim,
                is_training=context.train,
                dropout=self.relu_dropout)

        else:
            raise ValueError("Unknown ffn_layer type: %s" % ffn_layer_type)

        if postprocess:
            output = self._layer_norm(context,
                                      (output * mask) if mask else output)

        return output
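When more than one input is passed, the layer concatenates them along the model dimension and then projects back to model_dim with a dense layer, so the output width does not depend on the number of inputs. A NumPy-level sketch of that shape bookkeeping (illustrative names and sizes; the real layer uses mtf.layers.dense):

import numpy as np

batch, length, model_dim = 2, 5, 8
inputs_list = [np.random.randn(batch, length, model_dim) for _ in range(3)]

# Concatenate along the feature ("model") axis: width becomes 3 * model_dim.
ffn_inputs = np.concatenate(inputs_list, axis=-1)    # (2, 5, 24)

# Dense projection back to model_dim, standing in for mtf.layers.dense with
# reduced_dims=[ffn_inputs.shape.dims[-1]] and new_dims=[model_dim].
kernel = np.random.randn(3 * model_dim, model_dim)
bias = np.zeros(model_dim)
output = ffn_inputs @ kernel + bias                  # (2, 5, 8)
print(output.shape)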
Example #12
    def get_timing_signal_1d(self,
                             context,
                             length,
                             channels,
                             min_timescale=1.0,
                             max_timescale=1.0e4,
                             start_index=0):
        """Gets a bunch of sinusoids of different frequencies.

    Each channel of the input Tensor is incremented by a sinusoid of a different
    frequency and phase.

    This allows attention to learn to use absolute and relative positions.
    Timing signals should be added to some precursors of both the query and the
    memory inputs to attention.

    The use of relative position is possible because sin(x+y) and cos(x+y) can
    be expressed in terms of y, sin(x) and cos(x).

    In particular, we use a geometric sequence of timescales starting with
    min_timescale and ending with max_timescale.  The number of different
    timescales is equal to channels / 2. For each timescale, we
    generate the two sinusoidal signals sin(timestep/timescale) and
    cos(timestep/timescale).  All of these sinusoids are concatenated in
    the channels dimension.

    Args:
      context: mtf context.
      length: a mtf.Dimension, length of timing signal sequence.
      channels: a mtf.Dimension, size of timing embeddings to create.
      The number of different timescales is equal to channels / 2.
      min_timescale: a float
      max_timescale: a float
      start_index: index of first position

    Returns:
      a Tensor of timing signals [1, length, channels]
    """

        position = context.get_position() + start_index
        num_timescales = mtf.constant(context.mesh, channels.size // 2)
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            mtf.maximum(num_timescales - 1, 1))
        channel_dim_name = channels.name
        inv_timescales = (min_timescale * mtf.exp(
            mtf.mtf_range(context.mesh,
                          mtf.Dimension(channel_dim_name, channels.size // 2),
                          context.activation_dtype) * -log_timescale_increment)
                          )

        scaled_time = position * inv_timescales
        # Please note that this slightly differs from the published paper.
        # See a discussion here:
        # https://github.com/tensorflow/tensor2tensor/pull/177
        #    concat_dim_name = scaled_time.shape.dimension_names[1]
        concat_dim_name = channels.name
        signal = mtf.concat(
            [mtf.sin(scaled_time), mtf.cos(scaled_time)],
            concat_dim_name=concat_dim_name)

        if channels.size % 2 != 0:
            raise NotImplementedError("Odd channel size not implemented.")
        new_dims = [mtf.Dimension("expanded", 1)
                    ] + length.shape.dims + channels.shape.dim
        signal = mtf.reshape(signal, mtf.Shape(new_dims))
        return signal
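The signal built above is the standard sinusoidal position encoding: a geometric sequence of channels/2 timescales, with sin and cos of position/timescale concatenated along the channel dimension. A NumPy rendering of the same arithmetic (illustrative; it mirrors the formulas, not the mtf API):

import numpy as np

length, channels = 6, 8
min_timescale, max_timescale, start_index = 1.0, 1.0e4, 0

position = np.arange(length, dtype=np.float32) + start_index
num_timescales = channels // 2
log_timescale_increment = (
    np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1))
inv_timescales = min_timescale * np.exp(
    np.arange(num_timescales, dtype=np.float32) * -log_timescale_increment)

scaled_time = position[:, None] * inv_timescales[None, :]  # [length, channels/2]
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = signal[None, :, :]                                # [1, length, channels]
print(signal.shape)  # (1, 6, 8)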
Example #13
    def _call_internal(self,
                       context,
                       inputs,
                       targets=None,
                       attributes=None,
                       z=None):
        """Compute logits based on inputs (all positions in parallel).
        Also updates context if applicable.
        Args:
          context: a Context
          inputs: a Tensor
          targets: an optional Tensor
          attributes: an optional Tensor
          z: an optional Tensor
        Returns:
          logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim]
        """
        mesh = inputs.mesh
        if self.ensemble_dim and self.ensemble_dim not in inputs.shape.dims:
            # Training an ensemble where all models are trained on the same examples.
            inputs = mtf.broadcast(inputs,
                                   [self.ensemble_dim] + inputs.shape.dims)
            if self.ensemble_dim not in attributes.shape.dims:
                attributes = mtf.broadcast(attributes, [self.ensemble_dim] +
                                           attributes.shape.dims)
            if targets:
                targets = mtf.broadcast(targets, [self.ensemble_dim] +
                                        targets.shape.dims)
        if "embedding" in context.shared_params:
            vocab_embedding = context.shared_params["embedding"]
        else:
            vocab_embedding = VocabEmbedding(mesh,
                                             self.input_vocab_dim,
                                             self.model_dim,
                                             context.variable_dtype,
                                             name="embedding",
                                             ensemble_dim=self.ensemble_dim)
        x = vocab_embedding.ids_to_embedding(inputs)
        if self.positional_embedding:
            if "positional_embedding" in context.shared_params:
                pos_emb_var = context.shared_params["positional_embedding"]
            else:
                pos_emb_var = mtf.layers.embedding_weights(
                    mesh,
                    self.max_length_dim,
                    self.model_dim,
                    context.variable_dtype,
                    "positional_embedding",
                    ensemble_dim=self.ensemble_dim)
            if (context.length_dim is not None
                    and context.length_dim.size > self.max_length_dim.size):
                message = (
                    "Length dimenison exceeds size of positional embedding table. "
                    "length_dim.size > max_length_dim.size %s vs %s." %
                    (context.length_dim, self.max_length_dim))
                if context.position_is_default:
                    # Definitely getting overflow in this case.
                    raise ValueError(message)
                else:
                    tf.logging.warning(
                        message +
                        " This may be OK if there are several shorter sequences packed "
                        "together.  Otherwise, the later positions will get zeros."
                    )
            if context.position_is_default:
                pos_emb = mtf.rename_dimension(
                    mtf.slice(pos_emb_var, 0, context.length_dim.size,
                              self.max_length_dim.name),
                    self.max_length_dim.name, context.length_dim.name)
            else:
                pos_emb = mtf.gather(pos_emb_var,
                                     context.position,
                                     self.max_length_dim,
                                     output_shape=x.shape)
            x += pos_emb

        if self.attribute_embedding:
            if "attribute_embedding" in context.shared_params:
                att_emb_var = context.shared_params["attribute_embedding"]
            else:
                att_emb_var = mtf.layers.embedding_weights(
                    mesh,
                    self.attribute_dim,
                    self.model_dim,
                    context.variable_dtype,
                    "attribute_embedding",
                    ensemble_dim=self.ensemble_dim)

            att_emb = mtf.gather(att_emb_var,
                                 attributes,
                                 self.attribute_dim,
                                 output_shape=x.shape)
            # Addition of x and attribute
            # x *= LAMBDA_ATTRIBUTE * sty_emb #

            # Concatenation of x and attribute
            x_attribute = mtf.concat([x, att_emb], self.model_dim.name)
            x = mtf.layers.dense(x_attribute,
                                 self.model_dim,
                                 activation=None,
                                 variable_dtype=context.variable_dtype,
                                 name="comb_x_attribute")

        if z:
            z = mtf.layers.dense(z,
                                 self.model_dim,
                                 activation=None,
                                 variable_dtype=context.variable_dtype,
                                 name="z")
            # raise ValueError("x shape=%s , z shape=%s" % (x.shape, z.shape))
            x += z

        x = self.layer_stack.call(context, x)
        if self.output_vocab_dim is None:
            return x
        if self.shared_embedding_and_softmax_weights:
            logits = vocab_embedding.hidden_to_logits(x)
        else:
            logits = mtf.layers.dense(x,
                                      self.output_vocab_dim,
                                      use_bias=False,
                                      variable_dtype=context.variable_dtype,
                                      reduced_dims=x.shape.dims[-1:],
                                      name="logits")
        if targets is not None and context.losses is not None:
            context.losses.append(
                self._compute_loss(context, logits, targets,
                                   self.output_vocab_dim))
        if self.ensemble_dim:
            logits = reduce_ensemble_logits(logits, self.ensemble_dim,
                                            self.output_vocab_dim)
        return logits