Example #1
def widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16=None):
    logger.debug("[input tensor] (name,shape):({},{})".format(id_hldr.name, id_hldr.shape))
    logger.debug("[input tensor] (name,shape):({},{})".format(wt_hldr.name, wt_hldr.shape))
    if float16:
        deep_output = mtf.layers.embedding(id_hldr, vocab_dim=vocab_dim, output_dim=embed_dim, variable_dtype=float16, name="deep_embedding")
    else:
        fp32 = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
        deep_output = mtf.layers.embedding(id_hldr, vocab_dim=vocab_dim, output_dim=embed_dim, variable_dtype=fp32, name="deep_embedding")
    logger.debug("[output tensor] (name,shape):({},{})".format(deep_output.name, deep_output.shape))
    expand_dim = mtf.Dimension('expand', size=1)
    embed_dim_one = mtf.Dimension('embed_dim_one', size=1)
    mask = mtf.reshape(wt_hldr, new_shape=[wt_hldr.shape.dims[0], wt_hldr.shape.dims[1], expand_dim], name='mask_reshape')
    logger.debug("[output tensor] (name,shape):({},{})".format(mask.name, mask.shape))
    if float16:
        wide_output = mtf.layers.embedding(id_hldr, vocab_dim=vocab_dim, output_dim=embed_dim_one, variable_dtype=float16, name="wide_embedding")
    else:
        fp32 = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
        wide_output = mtf.layers.embedding(id_hldr, vocab_dim=vocab_dim, output_dim=embed_dim_one, variable_dtype=fp32, name="wide_embedding")
    logger.debug("[output tensor] (name,shape):({},{})".format(wide_output.name, wide_output.shape))

    wide_output = wide(wide_output, mask=mask, float16=float16)
    deep_output = deep(deep_output, mask=mask, float16=float16)

    result = mtf.add(wide_output, deep_output)
    result = mtf.reshape(result, new_shape=[wide_output.shape.dims[0], outdim], name='result_reshape')
    logger.debug("[output tensor] (name,shape):({},{})".format(result.name, result.shape))
    return result
Example #2
def get_variable_dtype(bf_16=True):
    # Trainable variable precision
    # Store checkpoints in master type, train in slice type, compute in activation type
    if bf_16:
        return mtf.VariableDType(master_dtype=tf.bfloat16,
                                 slice_dtype=tf.float32,
                                 activation_dtype=tf.bfloat16)
    else:
        return mtf.VariableDType(master_dtype=tf.float32,
                                 slice_dtype=tf.float32,
                                 activation_dtype=tf.float32)
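
The helper above just packages three dtypes into one object; a minimal sketch of inspecting the result (assuming mesh_tensorflow as mtf and tensorflow.compat.v1 as tf are imported, as in the surrounding examples):

dtype = get_variable_dtype(bf_16=True)
print(dtype.master_dtype)      # tf.bfloat16 -- checkpoint storage
print(dtype.slice_dtype)       # tf.float32  -- variables held in memory for training
print(dtype.activation_dtype)  # tf.bfloat16 -- forward/backward compute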
Example #3
def deep(x, mask, float16=None):
    x = mtf.einsum([x, mask], output_shape=x.shape.dims, name='deep_mul')
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))

    # Following MindSpore's practice, compute the dense layers below in fp16.
    x = mtf.cast(x, dtype=tf.float16)

    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim0', size=1024),
                         name="deep-dense-0",
                         reduced_dims=x.shape.dims[-2:],
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim1', size=512),
                         name="deep-dense-1",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim2', size=256),
                         name="deep-dense-2",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim3', size=128),
                         name="deep-dense-3",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim4', size=1),
                         name="deep-dense-4",
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    if not float16:
        # When running in pure fp32, cast the fp16 result back for the caller.
        x = mtf.cast(x, dtype=tf.float32)
    return x
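
deep() forces its dense stack into fp16 and only casts back to fp32 when the caller did not ask for mixed precision. A minimal, self-contained sketch of that cast round-trip (assumed names; mirrors the lowering pattern used in the test examples below):

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf
from mesh_tensorflow import placement_mesh_impl

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "sketch_mesh")
dim = mtf.Dimension("d", 4)
x = mtf.ones(mesh, mtf.Shape([dim]), tf.float32)
x = mtf.cast(x, dtype=tf.float16)  # the fp16 compute block would run here
x = mtf.cast(x, dtype=tf.float32)  # hand fp32 back, as deep() does when float16 is None
mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_x = lowering.export_to_tf_tensor(x)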
Example #4
def test_sampling():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", 1)

    inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)

    # create mask

    seq_len = params["n_ctx"]
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}

    other_features["attn_bias"] = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    params["mode"] = "predict"

    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs,
            other_features=other_features,
            params=params,
            variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=params["remove_partial_sequences"],
            stop_at_token=params["eos_id"],
            sampling_use_entmax=True)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)
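
This test reads a module-level params dict defined elsewhere in the test file. A hypothetical minimal stub covering just the keys touched here (the values are placeholders, not the project's real config):

params = {
    "n_ctx": 4,
    "num_mem_kv": 0,
    "n_embd": 8,
    "n_vocab": 16,
    "remove_partial_sequences": False,
    "eos_id": 0,
}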
Example #5
    def call_simple(self,
                    inputs,
                    targets,
                    compute_loss,
                    mode=tf.estimator.ModeKeys.TRAIN,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    encoder_sequence_id=None,
                    decoder_sequence_id=None):
        """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      encoder_sequence_id: an optional Tensor
      decoder_sequence_id: an optional Tensor

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
        shared_params = self._shared_params(inputs.mesh, variable_dtype)
        encoder_output, encoder_loss = self.encoder.call_simple(
            inputs,
            None,
            compute_loss,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=encoder_sequence_id,
            shared_params=shared_params)
        encoder_output = mtf.layers.rename_length_to_memory_length(
            encoder_output)
        if encoder_sequence_id is not None:
            encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
                encoder_sequence_id)
        length_dim = targets.shape.dims[-1]
        shifted_targets = mtf.shift(targets,
                                    offset=1,
                                    dim=length_dim,
                                    wrap=False)
        logits, loss = self.decoder.call_simple(
            shifted_targets,
            targets,
            compute_loss,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=decoder_sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params)
        if loss is not None and encoder_loss is not None:
            loss += encoder_loss
        return logits, loss
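
The decoder is fed targets shifted right by one position so that each position predicts the next token. A tiny sketch of what mtf.shift does here (assumed setup):

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "sketch_mesh")
length_dim = mtf.Dimension("length", 4)
targets = mtf.range(mesh, length_dim, tf.int32)                      # [0, 1, 2, 3]
shifted = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)   # [0, 0, 1, 2]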
Example #6
    def call_simple(self,
                    inputs,
                    targets,
                    compute_loss,
                    mode=tf.estimator.ModeKeys.TRAIN,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    sequence_id=None,
                    position=None,
                    encoder_output=None,
                    encoder_sequence_id=None,
                    shared_params=None):
        """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
        For training autoregressive models this should be equal to
        mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      sequence_id: an optional Tensor
      position: an optional Tensor
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      shared_params: an optional dictionary

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
        context = Context(mesh=inputs.mesh,
                          batch_dims=inputs.shape.dims[:-1],
                          length_dim=inputs.shape.dims[-1],
                          model_dim=self.model_dim,
                          variable_dtype=variable_dtype,
                          mode=mode,
                          autoregressive=self.autoregressive,
                          losses=[] if compute_loss else None,
                          sequence_id=sequence_id,
                          position=position,
                          encoder_output=encoder_output,
                          encoder_sequence_id=encoder_sequence_id,
                          shared_params=shared_params,
                          layout=self.layout,
                          mesh_shape=self.mesh_shape)
        with tf.variable_scope(self.name):
            logits = self._call_internal(context, inputs, targets)
        if compute_loss:
            loss = mtf.add_n(context.losses)
        else:
            loss = None
        return logits, loss
Example #7
    def setUp(self):
        super(FlatKeyValueMemoryTest, self).setUp()
        self.graph = mtf.Graph()
        self.mesh = mtf.Mesh(self.graph, "mtf_mesh")
        self.variable_dtype = mtf.VariableDType(activation_dtype=tf.float32)

        self.addCleanup(mock.patch.stopall)
        self.initializer_mock = mock.MagicMock()
        random_normal_initializer_mock = mock.patch.object(
            tf, "random_normal_initializer").start()
        random_normal_initializer_mock.return_value = self.initializer_mock
Example #8
    def setUp(self):
        super(AdaptiveSoftmaxTest, self).setUp()
        self.graph = mtf.Graph()
        self.mesh = mtf.Mesh(self.graph, 'mtf_mesh')
        self.variable_dtype = mtf.VariableDType(activation_dtype=tf.float32)

        self.addCleanup(mock.patch.stopall)
        self.initializer_mock = mock.MagicMock()
        random_normal_initializer_mock = mock.patch.object(
            tf, 'random_normal_initializer').start()
        random_normal_initializer_mock.return_value = self.initializer_mock
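
Both setUp methods above rely on VariableDType's defaults when only activation_dtype is given; a sanity sketch (assuming the usual mesh_tensorflow behavior that omitted dtypes fall back to master_dtype, which itself defaults to tf.float32):

dt = mtf.VariableDType(activation_dtype=tf.float32)
assert dt.master_dtype == tf.float32  # default
assert dt.slice_dtype == tf.float32   # falls back to master_dtype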
Example #9
def model_backbone(features, labels, mesh):
    """The model.
	Args:
		image: tf.Tensor with shape [batch, 32*32]
		labels: a tf.Tensor with shape [batch] and dtype tf.int32
		mesh: a mtf.Mesh
	Returns:
		logits: a mtf.Tensor with shape [batch, 10]
		loss: a mtf.Tensor with shape []
	"""
    id_hldr, wt_hldr = features

    batch_dim = mtf.Dimension("batch", args_opt.batch_size)
    field_dim = mtf.Dimension("field", size=39)
    vocab_dim = mtf.Dimension("vocab_size", 200000)
    embed_dim = mtf.Dimension("embed_size", 80)
    outdim = mtf.Dimension("outdim", 1)
    id_hldr = mtf.import_tf_tensor(
        mesh, tf.reshape(id_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    wt_hldr = mtf.import_tf_tensor(
        mesh, tf.reshape(wt_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    if args_opt.fp16:
        float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
        # id_hldr=mtf.cast(id_hldr,dtype=tf.int32)
        wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16)
    else:
        float16 = None

    logits, embedding_table = network[args_opt.model](id_hldr,
                                                      wt_hldr,
                                                      vocab_dim,
                                                      embed_dim,
                                                      outdim,
                                                      float16=float16)
    logits = mtf.cast(logits, dtype=tf.float32)
    embedding_table = mtf.cast(embedding_table, dtype=tf.float32)
    if labels is None:
        wide_loss = None
        deep_loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [args_opt.batch_size]),
            mtf.Shape([batch_dim]))
        wide_loss = mtf.layers.sigmoid_cross_entropy_with_logits(
            logits, labels)
        deep_loss = mtf.reduce_mean(mtf.square(embedding_table)) / 2
        deep_loss = mtf.reduce_mean(wide_loss) + 8e-5 * deep_loss
        wide_loss = mtf.reduce_mean(wide_loss)

    total_loss = None if wide_loss is None else wide_loss + deep_loss
    return logits, total_loss
Example #10
def get_dummy_decoder_context(converter,
                              batch=2,
                              d_model=6,
                              length=4,
                              mode="incremental",
                              initial_position=None,
                              state=None,
                              inputs=None):

    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)

    # Set up a dummy model
    layer_stack = transformer.LayerStack(layers=[])
    model = transformer.Unitransformer(
        d_model=d_model,
        input_vocab_size=10,  # dummy values
        output_vocab_size=10,  # dummy values
        autoregressive=True,
        max_length=length,
        layer_stack=layer_stack)

    if state is not None:
        state_mtf = converter.convert_np_array_to_mtf_tensor(
            state, dtype=tf.float32, dim_names=["batch", "length", "d_model"])
        states = [state_mtf]
    else:
        states = None

    if initial_position:
        initial_position = mtf.constant(converter.mesh,
                                        initial_position,
                                        shape=mtf.Shape([batch_dim]),
                                        dtype=tf.int32)

    if inputs is not None:
        inputs = converter.convert_np_array_to_mtf_tensor(
            inputs, dim_names=["batch", "length"])

    context = transformer.Context(model=model,
                                  mode=mode,
                                  states=states,
                                  new_states=[],
                                  mesh=converter.mesh,
                                  batch_dims=[batch_dim],
                                  length_dim=length_dim,
                                  variable_dtype=mtf.VariableDType(tf.float32),
                                  sequence_id=1,
                                  inputs=inputs,
                                  initial_position=initial_position)
    return context
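
The converter argument is a test utility (in mesh_tensorflow this role is played by a numpy-to-mtf helper along the lines of test_utils.NumpyConverter; treated here as an assumption). The interface this function actually relies on is small:

class ConverterSketch:  # hypothetical interface, for illustration only
    mesh: mtf.Mesh  # the mesh that returned tensors live on
    def convert_np_array_to_mtf_tensor(self, array, dtype=tf.float32, dim_names=None):
        """Imports a numpy array as an mtf.Tensor with the given dimension names."""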
Example #11
def test_model():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    seq_len = params["n_ctx"]

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", seq_len)

    features = {
        'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)),
                           tf.int32),
        'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)),
                           tf.int32)
    }

    # create mask

    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    other_features["attn_bias"] = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim, variable_dtype)
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    with not_raises(Exception):
        logits, _, _ = gpt2.model(features,
                                  other_features,
                                  params,
                                  mesh,
                                  variable_dtype=variable_dtype)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        logits = lowering.export_to_tf_tensor(logits)
Example #12
def get_variable_dtype(master_dtype=tf.bfloat16,
                       slice_dtype=tf.float32,
                       activation_dtype=tf.float32):
    """Datatypes to use for the run.

  Args:
    master_dtype: string, datatype for checkpoints
      keep this the same between training and eval/inference
    slice_dtype: string, datatype for variables in memory
      must be tf.float32 for training
    activation_dtype: string, datatype for activations
      less memory usage if tf.bfloat16 but possible numerical issues
  Returns:
    a mtf.VariableDtype
  """
    return mtf.VariableDType(master_dtype=tf.as_dtype(master_dtype),
                             slice_dtype=tf.as_dtype(slice_dtype),
                             activation_dtype=tf.as_dtype(activation_dtype))
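
Since everything is routed through tf.as_dtype, the arguments may be strings or tf.DType objects; a small usage sketch:

dtype = get_variable_dtype()  # bfloat16 master, float32 slices and activations
dtype = get_variable_dtype(master_dtype="float32",
                           activation_dtype="bfloat16")  # strings accepted via tf.as_dtype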
Example #13
    def testVariableOperations(self):
        var = mtf.Variable(self.mesh,
                           "test_variable",
                           self.ab_shape,
                           mtf.VariableDType(tf.int32, tf.int32, tf.int32),
                           initializer=tf.zeros_initializer(),
                           trainable=True)

        self.assertEqual(var.splittable_dims, frozenset(["a", "b"]))
        self.assertEqual(var.unsplittable_dims, frozenset())

        read_variable = mtf.ReadVariable(var)
        self.assertEqual(read_variable.splittable_dims, frozenset(["a", "b"]))
        self.assertEqual(read_variable.unsplittable_dims, frozenset())

        assign = mtf.Assign([var], [self.x])
        self.assertEqual(assign.splittable_dims, frozenset(["a", "b"]))
        self.assertEqual(assign.unsplittable_dims, frozenset())

        depend = mtf.Depend(read_variable.outputs[0], [assign])
        self.assertEqual(depend.splittable_dims, frozenset(["a", "b"]))
        self.assertEqual(depend.unsplittable_dims, frozenset())
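
self.ab_shape and self.x are fixtures created in the test's setUp. A hypothetical equivalent setup for running the snippet standalone:

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "mtf_mesh")
ab_shape = mtf.Shape([mtf.Dimension("a", 2), mtf.Dimension("b", 3)])
x = mtf.zeros(mesh, ab_shape, dtype=tf.int32)  # stands in for self.x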
Example #14
def model_backbone(image, labels, mesh):
	"""The model.
	Args:
		image: tf.Tensor with shape [batch, 224, 224, 3]
		labels: a tf.Tensor with shape [batch] and dtype tf.int32
		mesh: a mtf.Mesh
	Returns:
		logits: a mtf.Tensor with shape [batch, classes_num]
		loss: a mtf.Tensor with shape []
	"""
	batch_dim = mtf.Dimension("batch", args_opt.batch_size)
	rows_dim = mtf.Dimension("rows_size", 224)
	cols_dim = mtf.Dimension("cols_size", 224)
	channel_dim = mtf.Dimension("image_channel", 3)
	classes_dim = mtf.Dimension(name='classesnum',size=args_opt.class_num)
	x = mtf.import_tf_tensor(
		mesh, tf.reshape(image, [args_opt.batch_size, 224, 224, 3]),
		mtf.Shape(
			[batch_dim, rows_dim, cols_dim, channel_dim]))
	if args_opt.fp16:
		float16=mtf.VariableDType(tf.float16,tf.float16,tf.float16)
	else:
		float16=None

	logits = network[args_opt.model](x, classes_dim=classes_dim,float16=float16,batch_norm=False if 'vgg' in args_opt.model else True)
	logits = mtf.cast(logits,dtype=tf.float32)

	if labels is None:
		loss = None
	else:
		labels = mtf.import_tf_tensor(
			mesh, tf.reshape(labels, [args_opt.batch_size]), mtf.Shape([batch_dim]))
		loss = mtf.layers.softmax_cross_entropy_with_logits(
			logits, mtf.one_hot(labels, classes_dim), classes_dim)
		loss = mtf.reduce_mean(loss)
	return logits, loss
Example #15
def model_fn(features, labels, mode, params):
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"]
    }

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if type(features_dict[key]) == dict:
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])

        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if params["tokens_per_mb_per_replica"] is None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params["num_microbatches"] = num_microbatches  # Add num microbatches to params

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype)
                return {
                    "logits": logits,
                    "loss": loss,
                    "loss_batch": loss_batch
                }
            else:
                raise Exception(
                    f"'{params['model']}' is not a valid model - please select from [GPT]"
                )

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            raise Exception(
                f"'{params['model']}' is not a valid model - please select from [GPT]"
            )

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(
            global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers,
                                                   tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff added to it
                # so maybe this should be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
Example #16
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
    """Builds the UNet model graph, train op and eval metrics.

  Args:
    mesh: a MeshTensorflow.mesh object.
    mesh_impl: a mesh implementation, such as SimdMeshImpl and
      PlacementMeshImpl.
    dataset_str: a string of either train or eval. This is used for batch_norm.
    images: a laid out Tensor with shape [batch, x, y, num_channels]
      or [batch, x, y, z, num_channels].
    labels: a laid out Tensor with shape [batch, x, y, num_classes]
      or [batch, x, y, z, num_classes].

  Returns:
    Prediction and loss.
  """

    is_training = (dataset_str == 'train')
    if dataset_str == 'train':
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
    else:
        assert dataset_str == 'eval'
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
    image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
    image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
    image_sx_dim = mtf.Dimension('image_sx_block',
                                 FLAGS.ct_resolution // FLAGS.image_nx_block)
    image_sy_dim = mtf.Dimension('image_sy_block',
                                 FLAGS.ct_resolution // FLAGS.image_ny_block)
    image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
    image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
    label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
    mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

    mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
    variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

    # Import input features.
    x = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(images),
                                   mtf_images_shape)
    x = mtf.cast(x, mtf_dtype)

    # Import ground truth labels.
    t = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(labels),
                                   mtf_labels_shape)
    t = mtf.cast(t, mtf_dtype)

    # Transpose the blocks.
    if FLAGS.sampled_2d_slices:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            label_c_dim
        ])
    else:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, label_c_dim
        ])

    # Network.
    levels = []
    all_bn_update_ops = []
    # add levels with convolution or down-sampling
    for depth in range(FLAGS.network_depth):
        for n_conv in range(FLAGS.n_conv_per_block):
            if depth == 0 and n_conv == 0:
                # no dropout in 1st layer.
                dropout_keep_p = 1.0
            else:
                dropout_keep_p = FLAGS.dropout_keep_p
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_down_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)
        levels.append(x)

        if depth < FLAGS.network_depth - 1:
            if FLAGS.sampled_2d_slices:
                x = mtf.layers.max_pool2d(x, ksize=(2, 2))
            else:
                x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

    # add levels with up-convolution or up-sampling
    for depth in range(FLAGS.network_depth - 1)[::-1]:
        x = deconv_with_spatial_partition(
            x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
            FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
            'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
            variable_dtype, 'deconv_{}_0'.format(depth))
        x = mtf.concat([x, levels[depth]],
                       concat_dim_name='conv_{}_{}'.format(
                           depth, FLAGS.n_conv_per_block - 1))

        for n_conv in range(FLAGS.n_conv_per_block):
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_up_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)

    # no dropout in the final layer.
    if FLAGS.sampled_2d_slices:
        y = mtf.layers.conv2d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            h_blocks_dim=image_nx_dim,
            w_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )
    else:
        y = mtf.layers.conv3d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1, 1),
            strides=(1, 1, 1),
            padding='SAME',
            d_blocks_dim=image_nx_dim,
            h_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )

    # Use mtf.constant to make sure there are no CPU-side constants.
    def scalar(v, dtype):
        return mtf.constant(mesh, v, shape=[], dtype=dtype)

    argmax_t = mtf.argmax(t, label_c_dim)
    liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
    lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)

    argmax_y = mtf.argmax(y, label_c_dim)
    lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

    # summary of class ratios.
    lesion_pred_ratio = mtf.reduce_mean(lesion_y)
    lesion_label_ratio = mtf.reduce_mean(lesion_t)

    # summary of accuracy.
    accuracy = mtf.reduce_mean(
        mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

    # Cross-entropy loss. Up-weight the liver region.
    pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
        y, t, label_c_dim)
    pixel_weight = scalar(1, mtf_dtype) + \
        liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \
        lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight,
                          mtf_dtype)
    loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

    # Dice loss
    y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
    lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                                 reduced_dim=mtf.Dimension('label_c', 1))
    prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                    output_shape=mtf.Shape([batch_dim]))
    prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                   output_shape=mtf.Shape([batch_dim]))
    loss_dice_per_case = mtf.reduce_mean(
        scalar(-2, mtf_dtype) * prob_intersect /
        (prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
    loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(
        prob_intersect) / (mtf.reduce_sum(prob_area_sum) +
                           scalar(FLAGS.dice_epsilon, mtf_dtype))

    loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(
        0.5, mtf_dtype)

    loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
        1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

    intersect = mtf.reduce_sum(lesion_y * lesion_t,
                               output_shape=mtf.Shape([batch_dim]))
    area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                              output_shape=mtf.Shape([batch_dim]))
    # summary of dice.
    dice_per_case = mtf.reduce_mean(
        scalar(2, mtf_dtype) * intersect /
        (area_sum + scalar(0.000001, mtf_dtype)))
    dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
        mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

    eval_metrics = {
        'lesion_pred_ratio': lesion_pred_ratio,
        'lesion_label_ratio': lesion_label_ratio,
        'accuracy_of_all_classes': accuracy,
        'lesion_dice_per_case': dice_per_case,
        'lesion_dice_global': dice_global,
        'loss_xen': loss_xen,
        'loss_dice': loss_dice,
        'loss_dice_per_case': loss_dice_per_case,
        'loss_dice_global': loss_dice_global,
    }

    if FLAGS.sampled_2d_slices:
        y_prob_downsampled = mtf.layers.avg_pool2d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 2)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool2d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 2)
    else:
        y_prob_downsampled = mtf.layers.avg_pool3d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 3)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool3d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 3)

    liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
    lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
    preds = [
        mtf.reduce_sum(liver_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1)),
        mtf.reduce_sum(lesion_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1))
    ]

    if FLAGS.output_ground_truth:
        preds.append(
            mtf.reduce_sum(lesion_gt_downsampled,
                           reduced_dim=mtf.Dimension('label_c', 1)))

    preds.extend([intersect, area_sum])

    return preds, loss, eval_metrics, all_bn_update_ops
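
For reference, the two Dice terms above implement a soft Dice score on the lesion channel: the per-case term averages -2*sum(p*t) / (sum(p) + sum(t) + dice_epsilon) over the batch (p the predicted lesion probability, t the lesion label), the global variant pools the sums across the whole batch before dividing, and the final loss blends their mean with the weighted cross-entropy term according to FLAGS.dice_loss_weight.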
Example #17
    def variable_dtype(self):
        return mtf.VariableDType(tf.as_dtype(self._hparams.master_dtype),
                                 tf.as_dtype(self._hparams.slice_dtype),
                                 tf.as_dtype(self._hparams.activation_dtype))
Example #18
    def call_simple(self,
                    inputs,
                    targets,
                    compute_loss,
                    mode=tf.estimator.ModeKeys.TRAIN,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    encoder_sequence_id=None,
                    decoder_sequence_id=None,
                    decoder_subsequence_id=None,
                    encoder_position=None,
                    decoder_position=None,
                    num_microbatches=1):
        """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    This class inherits from transformer.Bitransformer with one difference: the
    encoder is a Funnel Transformer, so the length dimension is reduced. The
    decoder needs to use the updated `encoder_sequence_id`.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      encoder_sequence_id: an optional Tensor
      decoder_sequence_id: an optional Tensor
      decoder_subsequence_id: an optional Tensor
      encoder_position: an optional Tensor
      decoder_position: an optional Tensor
      num_microbatches: integer - greater than one if the step has been
        serialized into multiple microbatches to save memory.

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
        # encoder_sequence_id and decoder_sequence_id are used to delineate packed
        # examples but are also necessary to indicate padding where sequence_id==0.
        # If they are absent, then we assume that padding is indicated by zeros in
        # the inputs/targets, and we make up sequence_id tensors to indicate this.
        if encoder_sequence_id is None:
            encoder_sequence_id = mtf.minimum(inputs, 1)
        if decoder_sequence_id is None:
            decoder_sequence_id = mtf.minimum(targets, 1)
        encoder_layer_outputs = []
        shared_params = self._shared_params(inputs.mesh, variable_dtype)
        encoder_output, encoder_loss = self.encoder.call_simple(
            inputs,
            None,
            compute_loss,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=encoder_sequence_id,
            position=encoder_position,
            shared_params=shared_params,
            layer_outputs=encoder_layer_outputs,
            num_microbatches=num_microbatches)
        encoder_output = mtf.layers.rename_length_to_memory_length(
            encoder_output)

        # The sequence_id is updated inside the layer_stack due to pooling. So we
        # need to use the updated sequence_id stored in the context.
        encoder_sequence_id = self.encoder.layer_stack.context.sequence_id
        encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
            encoder_sequence_id)

        logits, loss = self.decoder.call_simple(
            transformer.autoregressive_inputs(targets,
                                              sequence_id=decoder_sequence_id),
            targets,
            compute_loss,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=decoder_sequence_id,
            subsequence_id=decoder_subsequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
            position=decoder_position,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            num_microbatches=num_microbatches)
        if loss is not None and encoder_loss is not None:
            loss += encoder_loss
        return logits, loss
Example #19
    def decode(self,
               inputs,
               variable_dtype=mtf.VariableDType(tf.float32),
               beam_size=1,
               alpha=0.6,
               temperature=0.0,
               sampling_keep_top_k=-1,
               decode_length_multiplier=1.5,
               decode_length_constant=10,
               max_decode_length=None):
        """Sampling or beam search for Funnel Transformer.

    Args:
      inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      sampling_keep_top_k: a value between 1 and vocab_size used to sample from
        only the k most likely logits. Set to -1 to sample from all logits.
      decode_length_multiplier: a float
      decode_length_constant: a float
      max_decode_length: an optional integer

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
        encoder_layer_outputs = []
        shared_params = self._shared_params(inputs.mesh, variable_dtype)
        encoder_sequence_id = mtf.minimum(inputs, 1)
        encoder_output, encoder_loss = self.encoder.call_simple(
            inputs=inputs,
            targets=None,
            compute_loss=False,
            mode=tf.estimator.ModeKeys.PREDICT,
            variable_dtype=variable_dtype,
            sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            layer_outputs=encoder_layer_outputs)
        del encoder_loss
        encoder_output = mtf.layers.rename_length_to_memory_length(
            encoder_output)

        # The sequence_id is updated inside the layer_stack due to pooling. So we
        # need to use the updated sequence_id stored in the context.
        encoder_sequence_id = self.encoder.layer_stack.context.sequence_id
        encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
            encoder_sequence_id)
        batch_dims = inputs.shape[:-1]
        length_dim = inputs.shape[-1]
        if max_decode_length is None:
            decode_length_dim = length_dim
        else:
            decode_length_dim = mtf.Dimension("length", max_decode_length)
        if beam_size == 1:
            ids_shape = mtf.Shape(batch_dims + [decode_length_dim])
            partial_sequences = mtf.zeros(inputs.mesh,
                                          ids_shape,
                                          dtype=tf.int32)
            return self.decoder.sample_autoregressive(
                partial_sequences,
                temperature=temperature,
                sampling_keep_top_k=sampling_keep_top_k,
                variable_dtype=variable_dtype,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                encoder_inputs=mtf.layers.rename_length_to_memory_length(
                    inputs),
                shared_params=shared_params,
                has_partial_sequences=False,
                encoder_layer_outputs=encoder_layer_outputs)
        else:
            if temperature != 0:
                raise ValueError(
                    "don't know how to beam search with nonzero temperature")
            if sampling_keep_top_k != -1:
                raise ValueError(
                    "don't know how to beam search with top-k value other than -1."
                )
            # beam search
            beam_dim = mtf.Dimension("beam", beam_size)
            ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim])
            partial_sequences = mtf.zeros(inputs.mesh,
                                          ids_shape,
                                          dtype=tf.int32)
            input_length = mtf.reduce_sum(mtf.to_float(
                mtf.cast(inputs, tf.bool)),
                                          reduced_dim=length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * decode_length_multiplier +
                decode_length_constant, tf.int32)
            return self.decoder.beam_search(
                partial_sequences,
                decode_length,
                variable_dtype=variable_dtype,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                encoder_inputs=inputs,
                alpha=alpha,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs)
Example #20
    def beam_search(self,
                    inputs,
                    decode_length,
                    dst_attributes=None,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    encoder_output=None,
                    encoder_sequence_id=None,
                    encoder_inputs=None,
                    alpha=0.6,
                    shared_params=None,
                    encoder_layer_outputs=None,
                    z=None):
        """Beam search.
        Args:
          inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim,
            length_dim].
          decode_length: an int32 mtf scalar.  Maximum decode length.
          dst_attributes: an optional int32 Tensor with shape
            [<batch_dims>, beam_dim, length_dim], [<batch_dims>, beam_dim],
            or [<batch_dims>]; bound to `attributes` internally.
          variable_dtype: a mtf.VariableDType
          encoder_output: an optional Tensor
          encoder_sequence_id: an optional Tensor
          encoder_inputs: an optional Tensor
          alpha: a floating point value (length bonus)
          shared_params: an optional dictionary
          encoder_layer_outputs: optional - readonly list of tensor activations when
            decoding, one per each input layer + the embedding layer
        Returns:
          a Tensor with shape [<batch_dims>, beam_dim, length_dim]
        """
        attributes = dst_attributes
        if not self.autoregressive:
            raise ValueError("must be autoregressive")

        batch_dims = inputs.shape.dims[:-2]
        if len(batch_dims) != 1:
            raise NotImplementedError(
                "beam search supports exactly one batch dimension.")
        beam_dim = inputs.shape.dims[-2]
        length_dim = inputs.shape.dims[-1]
        length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
        initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal(
            inputs, 0)),
                                          reduced_dim=length_dim)
        sequence_id = 1 if encoder_sequence_id is not None else None

        if self.input_full_attention:
            # This only makes sense in the case of beam search with given partial
            # sequences, which is not yet implemented.
            # TODO(noam): implement
            raise NotImplementedError(
                "Beam search for language models not yet implemented")
        else:
            read_priority = write_priority = length_range

        context_first_part = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims + [beam_dim],
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        shifted_inputs = mtf.shift(inputs,
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
        with tf.variable_scope(self.name):
            logits = self._call_internal(context_first_part,
                                         shifted_inputs,
                                         attributes=attributes,
                                         z=z)
        del logits
        # There are no partial targets.
        # Replace initial states by zeros to avoid computing them.
        initial_states = [
            mtf.zeros_like(t) for t in context_first_part.new_states
        ]
        constant_states = context_first_part.constant_states

        def logits_fn(step_num, ids, states):
            """logits_fn for mtf.beam_search.beam_search()."""
            inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)

            if self.attribute_embedding:
                attributes_this_step = mtf.gather(attributes, step_num - 1,
                                                  length_dim)
            else:
                attributes_this_step = None

            context_incremental = Context(
                model=self,
                mesh=inputs.mesh,
                batch_dims=batch_dims + [beam_dim],
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=step_num,
                states=states,
                new_states=[],
                sequence_id=sequence_id,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                constant_states=constant_states,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=step_num,
                inputs=inputs_this_step,
                encoder_inputs=encoder_inputs)
            with tf.variable_scope(self.name, reuse=True):
                logits = self._call_internal(context_incremental,
                                             inputs_this_step,
                                             attributes=attributes_this_step,
                                             z=z)
            return mtf.to_float(logits), context_incremental.new_states

        beams, unused_scores = mtf.beam_search.beam_search(
            logits_fn,
            inputs,
            alpha,
            states=initial_states,
            decode_length=decode_length,
            use_tpu=True,
            dtype=tf.float32,
            mesh_shape=self.mesh_shape,
            layout=self.layout)
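        # Return only the top-scoring beam (index 0 along beam_dim).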
        return mtf.gather(beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32),
                          beam_dim)

    def sample_autoregressive(self,
                              partial_sequences,
                              dst_attributes=None,
                              stop_at_token=1,
                              max_steps=None,
                              temperature=0.0,
                              variable_dtype=mtf.VariableDType(tf.float32),
                              encoder_output=None,
                              encoder_sequence_id=None,
                              encoder_inputs=None,
                              shared_params=None,
                              has_partial_sequences=True,
                              encoder_layer_outputs=None,
                              never_end=False,
                              remove_partial_sequences=False,
                              sampling_keep_top_k=-1,
                              z=None):
        """Sample randomly one token at a time.
        The partial_sequences represent partial sequences to be continued.  The
        first tokens of each sequence are nonzero representing the given partial
        sequences and the last tokens of each sequence are zeros, representing what
        needs to be filled in.
        If there are no partial sequences (you want to sample from the beginning),
        then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
        has_partial_sequences=False (so we can skip computation).
        The dst_attributes argument specifies the destination attributes for which we want to generate sequences.
        Args:
          partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
          dst_attributes: an optional int32 Tensor with shape
            [<batch_dims>, length_dim] (or [<batch_dims>])
          stop_at_token: an optional integer eos id.  Stop when we produce it.
          max_steps: an optional integer, the max number of steps to decode.
          temperature: an optional floating point value between 0.0 and 1.0.
            0.0 means argmax, 1.0 means sample according to predicted distribution.
          variable_dtype: a mtf.VariableDType
          encoder_output: an optional Tensor
          encoder_sequence_id: an optional Tensor
          encoder_inputs: an optional Tensor
          shared_params: an optional dictionary
          has_partial_sequences: a boolean
          encoder_layer_outputs: optional - readonly list of tensor activations when
            decoding, one per each input layer + the embedding layer
          never_end: a boolean - if set, then avoid generating stop_at_token
          remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
          sampling_keep_top_k: an integer - if not -1, only sample from the top k
            logits.
          z: an optional Tensor (decoder cross-attention input; see call_simple)
        Returns:
          a Tensor with shape [<batch_dims>, length_dim]
        """
        if not self.autoregressive:
            raise ValueError("must be autoregressive")

        inputs = partial_sequences
        attributes = dst_attributes
        batch_dims = inputs.shape.dims[:-1]
        length_dim = inputs.shape.dims[-1]
        initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal(
            inputs, 0)),
                                          reduced_dim=length_dim)
        sequence_id = 1 if encoder_sequence_id is not None else None

        length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
        if self.input_full_attention:
            read_priority = write_priority = length_range * mtf.to_int32(
                mtf.greater(length_range, initial_position))
        else:
            read_priority = write_priority = length_range

        context_first_part = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        shifted_inputs = mtf.shift(inputs,
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
        with tf.variable_scope(self.name):
            logits = self._call_internal(context_first_part,
                                         shifted_inputs,
                                         attributes=attributes,
                                         z=z)
        del logits
        constant_states = context_first_part.constant_states
        if not has_partial_sequences:
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
            partial_sequences_eos_count = 0
        else:
            initial_states = context_first_part.new_states
            partial_sequences_eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
                reduced_dim=length_dim)

        def cond_fn(position, ids, *unused_states):
            """Should we run another loop iteration."""
            past_end = mtf.greater_equal(position, length_dim.size)
            if max_steps:
                past_end = mtf.logical_or(
                    past_end,
                    mtf.greater_equal(position - initial_position, max_steps))

            is_done = past_end
            if stop_at_token is not None:
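                # A sequence is done once it contains more EOS tokens than its
                # given partial prefix did.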
                eos_count = mtf.reduce_sum(mtf.to_int32(
                    mtf.equal(ids, stop_at_token)),
                                           reduced_dim=length_dim)
                has_additional_eos = mtf.greater(eos_count,
                                                 partial_sequences_eos_count)
                is_done = mtf.logical_or(is_done, has_additional_eos)
            all_done = mtf.reduce_all(is_done)
            return mtf.logical_not(all_done)

        def body_fn(position, ids, *states):
            """One step in the decode loop."""
            inputs_this_step = mtf.gather(ids, position - 1, length_dim)
            if self.attribute_embedding:
                attributes_this_step = mtf.gather(attributes, position - 1,
                                                  length_dim)
            else:
                attributes_this_step = None
            context_incremental = Context(
                model=self,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                states=states,
                new_states=[],
                sequence_id=sequence_id,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                constant_states=constant_states,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=position,
                inputs=inputs_this_step,
                encoder_inputs=encoder_inputs)

            with tf.variable_scope(self.name, reuse=True):
                logits = self._call_internal(context_incremental,
                                             inputs_this_step,
                                             attributes=attributes_this_step,
                                             z=z)
                if never_end:
                    logits += mtf.one_hot(mtf.constant(logits.mesh,
                                                       stop_at_token,
                                                       dtype=tf.int32),
                                          self.output_vocab_dim,
                                          on_value=-1e9,
                                          off_value=0.0,
                                          dtype=logits.dtype)

            # TBD whether this should be before or after never_end:
            # Note for adding top_p sampling in the future, in other code bases, the
            # option to apply temperature is done before the top-k truncation. This
            # implementation does this in the opposite order. For top-k this doesn't
            # matter, but for top_p it will.
            if sampling_keep_top_k != -1:
                if sampling_keep_top_k <= 0:
                    raise ValueError(
                        "sampling_keep_top_k must either be -1 or positive.")
                k_largest = mtf.nth_largest_element(
                    logits,
                    n=sampling_keep_top_k,
                    reduced_dim=self.output_vocab_dim)
                logits = mtf.where(mtf.less_equal(logits, k_largest),
                                   mtf.ones_like(logits) * -1e6, logits)

            ids_this_step = mtf.sample_with_temperature(
                logits, self.output_vocab_dim, temperature)
            new_position = position + 1
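            # Scatter the sampled id into the sequence at `position` via a
            # one-hot mask.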
            new_ids = ids + ids_this_step * mtf.one_hot(
                position, length_dim, dtype=tf.int32)
            return [new_position, new_ids] + context_incremental.new_states

        while_loop_inputs = [initial_position, inputs] + initial_states
        final_position, outputs = mtf.while_loop(cond_fn, body_fn,
                                                 while_loop_inputs)[:2]
        del final_position
        if has_partial_sequences and remove_partial_sequences:
            # remove partial sequences from outputs
            partial_length = mtf.reduce_sum(mtf.to_int32(
                mtf.not_equal(partial_sequences, 0)),
                                            reduced_dim=length_dim)
            outputs = mtf.dynamic_shift(outputs,
                                        -partial_length,
                                        length_dim,
                                        wrap=False)
        return outputs
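
# A hedged usage sketch for sample_autoregressive above (illustrative only;
# an autoregressive `model` and the mesh/dimension objects are assumed to
# exist, as in the surrounding examples):
#
#   partial_sequences = mtf.zeros(mesh, mtf.Shape(batch_dims + [length_dim]),
#                                 dtype=tf.int32)
#   samples = model.sample_autoregressive(
#       partial_sequences,
#       temperature=0.9,
#       has_partial_sequences=False)  # sample from scratch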
Exemple #22
0
def create_dummy_model(mesh,
                       shapes,
                       n_blocks=2,
                       block_param_size_str="2_2",
                       block_repeat_size_str="1_1"):
    """Creates a dummy model and layer stack with 4-dimensional input."""

    assert len(shapes) == 4
    outer_batch_size, batch_size, length, d_model = shapes
    batch_dim = mtf.Dimension("batch", batch_size)
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    length_dim = mtf.Dimension("length", length)
    block_param_size = list(map(int, block_param_size_str.split("_")))
    block_repeat_size = list(map(int, block_repeat_size_str.split("_")))

    sublayers_initial = [
        transformer.sublayer_dropout,
    ]
    sublayers_per_layer = [
        transformer.sublayer_rms_norm,
        transformer.sublayer_call_layer,
        transformer.sublayer_dropout,
        transformer.sublayer_residual,
    ]
    sublayers_final = [
        transformer.sublayer_rms_norm,
        transformer.sublayer_dropout,
    ]
    submodules = [
        transformer_layers.SelfAttention(),
        transformer_layers.DenseReluDense()
    ]

    n_sublayers = np.array(block_param_size).prod()
    layers = submodules * n_sublayers
    layer_stack = funnel_transformer.FunnelTransformerLayerStack(
        layers=layers,
        n_blocks=n_blocks,
        block_param_size=block_param_size,
        block_repeat_size=block_repeat_size,
        sublayers_initial=sublayers_initial,
        sublayers_per_layer=sublayers_per_layer,
        sublayers_final=sublayers_final)

    model = transformer.Unitransformer(input_vocab_size=10,
                                       output_vocab_size=10,
                                       autoregressive=False,
                                       max_length=8,
                                       d_model=d_model,
                                       layer_stack=layer_stack)

    context = transformer.Context(model=model,
                                  mesh=mesh,
                                  batch_dims=[batch_dim, outer_batch_dim],
                                  length_dim=length_dim,
                                  variable_dtype=mtf.VariableDType(tf.float32),
                                  sequence_id=mtf.ones(mesh,
                                                       mtf.Shape([length_dim
                                                                  ])),
                                  position=mtf.range(mesh,
                                                     length_dim,
                                                     dtype=tf.int32))
    return layer_stack, context
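
# A hedged usage sketch for create_dummy_model (illustrative; `mesh` is
# assumed to exist):
#
#   layer_stack, context = create_dummy_model(
#       mesh, shapes=[1, 2, 8, 4])  # outer_batch, batch, length, d_model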
Exemple #23
0
    def call_simple(self,
                    inputs,
                    targets,
                    compute_loss,
                    attributes=None,
                    codeprefixedtargets=None,
                    mode=tf.estimator.ModeKeys.TRAIN,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    encoder_sequence_id=None,
                    decoder_sequence_id=None,
                    decoder_subsequence_id=None,
                    encoder_position=None,
                    decoder_position=None):  # attributes=None for debugging?
        """Compute logits based on inputs (all positions in parallel).
        This is called during training and evaluation.
        Args:
          inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
          targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
          compute_loss: a boolean
          attributes: an optional int32 Tensor with shape
            [<batch_dims>, length_dim] (or [<batch_dims>])
          codeprefixedtargets: an optional int32 Tensor; when provided, it is
            shifted (without offset) and used as the decoder input in place of
            the shifted targets
          mode: a tf.estimator.ModeKeys
          variable_dtype: a mtf.VariableDType
          encoder_sequence_id: an optional Tensor
          decoder_sequence_id: an optional Tensor
          decoder_subsequence_id: an optional Tensor
          encoder_position: an optional Tensor
          decoder_position: an optional Tensor
        Returns:
          logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
          loss: an optional Scalar (if compute_loss=True)
        """
        encoder_layer_outputs = []
        shared_params = self._shared_params(inputs.mesh, variable_dtype)
        encoder_output, encoder_loss = self.encoder.call_simple(
            inputs,
            None,
            compute_loss,
            attributes=attributes,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=encoder_sequence_id,
            position=encoder_position,
            shared_params=shared_params,
            layer_outputs=encoder_layer_outputs)
        encoder_output = mtf.layers.rename_length_to_memory_length(
            encoder_output)
        if encoder_sequence_id is not None:
            encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
                encoder_sequence_id)

        if self.cut_cross_attention:
            z = mtf.gather(
                encoder_output,
                mtf.zeros(inputs.mesh,
                          mtf.Shape(inputs.shape[:-1] +
                                    [encoder_output.shape[-1]]),
                          dtype=tf.int32), encoder_output.shape[-2])
            encoder_output = None
        else:
            z = None

        if codeprefixedtargets:
            # Alternatives considered here: shift_attribute_targets(targets,
            # attribute_id=codeprefixedtargets), codeprefixedtargets itself,
            # or mtf.zeros_like(targets).
            decoder_input = shift_targets_no_offset(codeprefixedtargets)
        else:
            decoder_input = shift_targets(targets)

        # Earlier variant (removes the token preceding ":"):
        # shift_targets = mtf.shift(targets, offset=-1, dim=targets.shape.dims[-1], wrap=False)
        # shift_targets = mtf.shift(shift_targets, offset=1, dim=targets.shape.dims[-1], wrap=False)
        logits, loss = self.decoder.call_simple(
            decoder_input,
            targets,
            compute_loss,
            attributes=attributes,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=decoder_sequence_id,
            subsequence_id=decoder_subsequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
            position=decoder_position,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            z=z)  # Maybe sample_autoregressive here?

        if loss is not None and encoder_loss is not None:
            loss += encoder_loss
        return logits, loss
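
# A hedged usage sketch for call_simple above (illustrative; the bitransformer
# `model` and the input/target tensors are assumed to exist, as in my_model_fn
# near the end of this file):
#
#   logits, loss = model.call_simple(
#       inputs=inputs,    # int32 [<batch_dims>, length_dim]
#       targets=targets,  # int32 [<batch_dims>, length_dim]
#       compute_loss=True,
#       mode=tf.estimator.ModeKeys.TRAIN,
#       variable_dtype=mtf.VariableDType(tf.float32))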
Exemple #24
0
def run(tpu_job_name,
        data_dir,
        master_dtype,
        slice_dtype,
        activation_dtype,
        tpu,
        gcp_project,
        tpu_zone,
        autostack,
        model_dir,
        mode=gin.REQUIRED,
        iterations_per_loop=gin.REQUIRED,
        save_checkpoints_steps=gin.REQUIRED,
        eval_steps=gin.REQUIRED,
        train_steps=gin.REQUIRED,
        batch_size=gin.REQUIRED,
        text2self=gin.REQUIRED,
        dataset=gin.REQUIRED):
    """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    data_dir: string, data_dir for TensorFlow Datasets
    master_dtype: string, datatype for checkpoints
    slice_dtype: string, datatype for variables in memory
    activation_dtype: string, datatype for activations
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located in
    autostack: boolean, internally combine variables
    model_dir: string, estimator model_dir
    mode: string, train/evaluate/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    eval_steps: integer, number of evaluation steps
    train_steps: integer, total number of training steps
    batch_size: integer, global batch size (not the per-shard batch size)
    text2self: boolean, whether to train a language model (True) or an
      encoder-decoder text-to-text model (False)
    dataset: string, TensorFlow Datasets dataset name
  """
    cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu or "", zone=tpu_zone, project=gcp_project)

    my_tpu_config = tpu_config.TPUConfig(
        tpu_job_name=tpu_job_name,
        iterations_per_loop=iterations_per_loop,
        num_cores_per_replica=1,
        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
    )

    run_config = tpu_config.RunConfig(
        cluster=cluster,
        model_dir=model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        tpu_config=my_tpu_config)

    dataset = transformer_dataset.TokenizedTFDSDataset(dataset,
                                                       text2self=text2self,
                                                       data_dir=data_dir
                                                       or None)

    output_encoder = dataset.encoders["targets"]
    if text2self:
        input_encoder = output_encoder
    else:
        input_encoder = dataset.encoders["inputs"]

    transformer_model = model(
        input_vocab_size=transformer_dataset.padded_vocab_size(
            input_encoder.vocab_size, 128),
        output_vocab_size=transformer_dataset.padded_vocab_size(
            output_encoder.vocab_size, 128),
        text2self=text2self)
    mesh_shape = mtf.convert_to_shape(gin.query_parameter("model.mesh_shape"))
    layout_rules = mtf.convert_to_layout_rules(
        gin.query_parameter("model.layout"))
    # Data-types used for variables and activations
    # See comments in the FLAGS
    master_dtype = tf.as_dtype(master_dtype)
    if slice_dtype:
        slice_dtype = tf.as_dtype(slice_dtype)
    elif not tpu or mode == "train":
        slice_dtype = tf.float32
    else:
        slice_dtype = tf.bfloat16
    if activation_dtype:
        activation_dtype = tf.as_dtype(activation_dtype)
    else:
        activation_dtype = tf.bfloat16 if tpu else tf.float32
    variable_dtype = mtf.VariableDType(master_dtype=master_dtype,
                                       slice_dtype=slice_dtype,
                                       activation_dtype=activation_dtype)

    length_from_config = gin.query_parameter(
        "model.length") or gin.query_parameter("model.max_length")

    model_fn = tpu_estimator_model_fn(transformer_model=transformer_model,
                                      model_dir=model_dir,
                                      use_tpu=tpu,
                                      mesh_shape=mesh_shape,
                                      layout_rules=layout_rules,
                                      text2self=text2self,
                                      variable_dtype=variable_dtype,
                                      batch_size=batch_size,
                                      length=length_from_config,
                                      autostack=autostack)

    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           config=run_config,
                                           train_batch_size=batch_size,
                                           eval_batch_size=batch_size,
                                           predict_batch_size=batch_size,
                                           use_tpu=tpu,
                                           export_to_tpu=False,
                                           params={})

    def input_fn(params):
        del params
        return dataset.load(batch_size=batch_size,
                            length=length_from_config,
                            train=(mode == "train"),
                            pack=True)

    if mode == "train":
        estimator.train(input_fn=input_fn, max_steps=train_steps)
    elif mode == "evaluate":
        estimator.evaluate(
            input_fn=input_fn,
            steps=eval_steps,
        )
    elif mode == "infer":
        decode_from_file(estimator,
                         batch_size=batch_size,
                         length=length_from_config,
                         inputs_encoder=dataset.
                         encoders["targets" if text2self else "inputs"],
                         targets_encoder=dataset.encoders["targets"],
                         text2self=text2self)
    else:
        raise ValueError("unknown mode %s - must be train/evaluate/infer" %
                         mode)
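
# A hedged invocation sketch for run() (illustrative values; the gin.REQUIRED
# arguments must be supplied through gin bindings before calling):
#
#   run(tpu_job_name=None, data_dir="/tmp/tfds", master_dtype="bfloat16",
#       slice_dtype="", activation_dtype="", tpu=None, gcp_project=None,
#       tpu_zone=None, autostack=True, model_dir="/tmp/model")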
Exemple #25
0
    # Use randint so ids are valid (non-negative) rows of the embedding table.
    id_hldr = mtf.import_tf_tensor(
        mesh,
        np.random.randint(0, vocab_dim.size,
                          size=(batch_dim.size, field_dim.size)).astype(np.int32),
        shape=[batch_dim, field_dim])
    wt_hldr = mtf.import_tf_tensor(
        mesh,
        np.random.randn(batch_dim.size, field_dim.size).astype(np.float32),
        shape=[batch_dim, field_dim])
    label = mtf.import_tf_tensor(mesh,
                                 np.array([1, 0, 1, 0, 1, 0, 1, 0, 1,
                                           0]).astype(np.float32),
                                 shape=[batch_dim])

    fp16 = True
    if fp16:
        float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
        id_hldr = mtf.cast(id_hldr, dtype=tf.int32)
        wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16)
    else:
        float16 = None
    result, embedding_table = widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim,
                                       outdim, float16)

    # label = mtf.reshape(label,new_shape=[batch_dim, outdim])
    # output = mtf.layers.softmax_cross_entropy_with_logits(result, label,vocab_dim=outdim)
    # result = mtf.sigmoid(result)
    # result = -(label*mtf.log(result)+(1-label)*mtf.log(1-result))
    # result = mtf.reduce_sum(result)
    result = mtf.cast(result, dtype=tf.float32)
    embedding_table = mtf.cast(embedding_table, dtype=tf.float32)
    probability = mtf.sigmoid(result)
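    # A possible continuation (sketch), mirroring the commented-out loss above:
    #
    #   result = -(label * mtf.log(probability) +
    #              (1 - label) * mtf.log(1 - probability))
    #   loss = mtf.reduce_sum(result)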
Exemple #26
0
  def beam_search(self,
                  inputs,
                  decode_length,
                  variable_dtype=mtf.VariableDType(tf.float32),
                  encoder_output=None,
                  encoder_sequence_id=None,
                  alpha=0.6,
                  shared_params=None,
                  encoder_layer_outputs=None):
    """Beam search.

    Args:
      inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim,
        length_dim].
      decode_length: an int32 mtf scalar.  Maximum decode length.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      alpha: a floating point value (length bonus)
      shared_params: an optional dictionary
      encoder_layer_outputs: optional - readonly list of tensor activations when
        decoding, one per each input layer + the embedding layer

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    if not self.autoregressive:
      raise ValueError("must be autoregressive")

    batch_dims = inputs.shape.dims[:-2]
    if len(batch_dims) != 1:
      raise NotImplementedError(
          "beam search supports exactly one batch dimension.")
    beam_dim = inputs.shape.dims[-2]
    length_dim = inputs.shape.dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    context_first_part = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        autoregressive=self.autoregressive,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
      logits = self._call_internal(context_first_part, shifted_inputs)
    del logits
    # There are no partial targets.
    # Replace initial states by zeros to avoid computing them.
    initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
    constant_states = context_first_part.constant_states

    def logits_fn(step_num, ids, states):
      """logits_fn for mtf.beam_search.beam_search()."""
      context_incremental = Context(
          mesh=inputs.mesh,
          batch_dims=batch_dims + [beam_dim],
          length_dim=length_dim,
          model_dim=self.model_dim,
          variable_dtype=variable_dtype,
          mode="incremental",
          autoregressive=self.autoregressive,
          position=step_num,
          states=states,
          new_states=[],
          sequence_id=sequence_id,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          constant_states=constant_states,
          shared_params=shared_params,
          layout=self.layout,
          mesh_shape=self.mesh_shape,
          encoder_layer_outputs=encoder_layer_outputs)
      inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
      with tf.variable_scope(self.name, reuse=True):
        logits = self._call_internal(context_incremental, inputs_this_step)
      return mtf.to_float(logits), context_incremental.new_states

    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)
    return mtf.gather(
        beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
Exemple #27
0
  def sample_autoregressive(self,
                            partial_sequences,
                            stop_at_token=1,
                            max_steps=None,
                            temperature=1.0,
                            variable_dtype=mtf.VariableDType(tf.float32),
                            encoder_output=None,
                            encoder_sequence_id=None,
                            shared_params=None,
                            has_partial_sequences=True,
                            encoder_layer_outputs=None):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued.  The
    first tokens of each sequence are nonzero representing the given partial
    sequences and the last tokens of each sequence are zeros, representing what
    needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
      stop_at_token: an optional integer eos id.  Stop when we produce it.
      max_steps: an optional integer (not yet implemented)
      temperature: an optional floating point value between 0.0 and 1.0.
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations when
        decoding, one per each input layer + the embedding layer

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]
    """
    del max_steps  # TODO(noam): implement
    if not self.autoregressive:
      raise ValueError("must be autoregressive")

    inputs = partial_sequences
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    context_first_part = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        autoregressive=self.autoregressive,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
      logits = self._call_internal(context_first_part, shifted_inputs)
    del logits
    constant_states = context_first_part.constant_states
    if not has_partial_sequences:
      initial_states = [
          mtf.zeros_like(t) for t in context_first_part.new_states]
    else:
      initial_states = context_first_part.new_states

    def cond_fn(position, ids, *unused_states):
      """Should we run another loop iteration."""
      past_end = mtf.greater_equal(position, length_dim.size)
      is_done = past_end
      if stop_at_token is not None:
        has_eos = mtf.reduce_any(
            mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
        is_done = mtf.logical_or(is_done, has_eos)
      all_done = mtf.reduce_all(is_done)
      return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
      """One step in the decode loop."""
      context_incremental = Context(
          mesh=inputs.mesh,
          batch_dims=batch_dims,
          length_dim=length_dim,
          model_dim=self.model_dim,
          variable_dtype=variable_dtype,
          mode="incremental",
          autoregressive=self.autoregressive,
          position=position,
          states=states,
          new_states=[],
          sequence_id=sequence_id,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          constant_states=constant_states,
          shared_params=shared_params,
          layout=self.layout,
          mesh_shape=self.mesh_shape,
          encoder_layer_outputs=encoder_layer_outputs)
      inputs_this_step = mtf.gather(ids, position - 1, length_dim)
      with tf.variable_scope(self.name, reuse=True):
        logits = self._call_internal(context_incremental, inputs_this_step)
      ids_this_step = mtf.sample_with_temperature(
          logits, self.output_vocab_dim, temperature)
      new_position = position + 1
      new_ids = ids + ids_this_step * mtf.one_hot(
          position, length_dim, dtype=tf.int32)
      return [new_position, new_ids] + context_incremental.new_states
    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(
        cond_fn, body_fn, while_loop_inputs)[:2]
    del final_position
    return outputs
Exemple #28
0
def sample_autoregressive(
    partial_sequences,
    other_features,
    params,
    stop_at_token=50256,
    max_steps=None,
    temperature=0.9,
    variable_dtype=mtf.VariableDType(tf.float32),
    encoder_output=None,
    encoder_sequence_id=None,
    encoder_inputs=None,
    shared_params=None,
    has_partial_sequences=True,
    encoder_layer_outputs=None,
    never_end=False,
    remove_partial_sequences=False,
    sampling_keep_top_k=-1,
    bos_id=50256,
):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued.  The
    first tokens of each sequence are nonzero representing the given partial
    sequences and the last tokens of each sequence are zeros, representing what
    needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
        partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
        other_features: a features dictionary; must contain "vocab_dim"
        params: a hyperparameter dictionary; "padding_id" and "slow_sampling"
            are read here
        stop_at_token: an optional integer eos id.  Stop when we produce it.
        max_steps: an optional integer, the max number of steps to decode.
        temperature: an optional floating point value between 0.0 and 1.0.
            0.0 means argmax, 1.0 means sample according to predicted distribution.
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        shared_params: an optional dictionary
        has_partial_sequences: a boolean
        encoder_layer_outputs: optional - readonly list of tensor activations when
            decoding, one per each input layer + the embedding layer
        never_end: a boolean - if set, then avoid generating stop_at_token
        remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
        sampling_keep_top_k: an integer - if not -1, only sample from the top k
            logits (-2 means the top 10% of the vocabulary)
        bos_id: beginning of sequence id

    Returns:
        a Tensor with shape [<batch_dims>, length_dim]
    """

    inputs = partial_sequences  # Partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)),
        reduced_dim=length_dim)  # Gets position where zero padding starts

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    input_full_attention = True  # hardcoded to True for now
    if input_full_attention:
        # Vanilla autoregressive model - each position can see previous positions.
        # This feeds into the loop fn and tells each position where it may attend.
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    # Builds context to pass around internally
    # The 'first part' context records initial states of k / v / x

    if not slow_sampling:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2"):
            logits, _, _ = gpt2.model({"inputs": inputs},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context_first_part)

        if not has_partial_sequences:
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
        else:
            initial_states = context_first_part.new_states
    else:
        initial_states = []

    if not has_partial_sequences:
        partial_sequences_eos_count = 0

    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(mtf.to_int32(
            mtf.equal(partial_sequences, stop_at_token)),
                                                     reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(mtf.to_int32(
                mtf.equal(ids, stop_at_token)),
                                       reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count,
                                             partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        nonlocal sampling_keep_top_k

        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=ids,
            encoder_inputs=encoder_inputs) if not slow_sampling else None

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)

        # sampling_keep_top_k == -2 means: restrict sampling to the top 10%
        # of the vocabulary.
        if sampling_keep_top_k == -2:
            sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits,
                n=sampling_keep_top_k,
                reduced_dim=other_features["vocab_dim"])
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, other_features["vocab_dim"], temperature)

        if slow_sampling:
            ids_this_step = mtf.shift(ids_this_step,
                                      offset=1,
                                      dim=length_dim,
                                      wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, batch_dims)

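        # Overwrite position `position`: zero the old id with (1 - one_hot),
        # then add the sampled id through the one-hot mask.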
        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(cond_fn, body_fn,
                                             while_loop_inputs)[:2]
    del final_position
    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs
        partial_length = mtf.reduce_sum(mtf.to_int32(
            mtf.not_equal(partial_sequences, padding_id)),
                                        reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(outputs,
                                    -partial_length,
                                    length_dim,
                                    wrap=False)
    return outputs
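
# Standalone NumPy sketch (editor-added, runnable) of the top-k truncation idea
# used in body_fn above: everything below a per-row top-k threshold is pushed
# to -1e6 so sampling effectively ignores it. Exact tie/indexing semantics of
# mtf.nth_largest_element may differ slightly.
import numpy as np

def _keep_top_k_demo(logits, k):
    # k-th largest value in each row, broadcast back over the vocab axis.
    kth_largest = np.sort(logits, axis=-1)[..., -k]
    return np.where(logits < kth_largest[..., None], -1e6, logits)

# _keep_top_k_demo(np.array([[1.0, 3.0, 2.0, 5.0]]), k=2)
# -> [[-1e6, 3.0, -1e6, 5.0]]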
Exemple #29
0
  def decode(self,
             inputs,
             variable_dtype=mtf.VariableDType(tf.float32),
             beam_size=1,
             alpha=0.6,
             temperature=0.0,
             decode_length_multiplier=1.5,
             decode_length_constant=10):
    """Sampling or beam search.

    TODO(noam): should we make the output length dimension different from the
    input length dimension?

    Args:
      inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      decode_length_multiplier: a float
      decode_length_constant: a float

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_sequence_id = mtf.minimum(inputs, 1)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs=inputs,
        targets=None,
        compute_loss=False,
        mode=tf.estimator.ModeKeys.PREDICT,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs)
    del encoder_loss
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)
    if beam_size == 1:
      ids_shape = inputs.shape
      partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
      return self.decoder.sample_autoregressive(
          partial_sequences,
          temperature=temperature,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          shared_params=shared_params,
          has_partial_sequences=False,
          encoder_layer_outputs=encoder_layer_outputs)
    else:
      if temperature != 0:
        raise ValueError(
            "don't know how to beam search with nonzero temperature")
      # beam search
      beam_dim = mtf.Dimension("beam", beam_size)
      batch_dims = inputs.shape[:-1]
      length_dim = inputs.shape[-1]
      ids_shape = mtf.Shape(batch_dims + [beam_dim, length_dim])
      partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * decode_length_multiplier
          + decode_length_constant, tf.int32)
      return self.decoder.beam_search(
          partial_sequences,
          decode_length,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          alpha=alpha,
          shared_params=shared_params,
          encoder_layer_outputs=encoder_layer_outputs)

def my_model_fn(features,
                labels,
                mode,
                params=None,
                config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: an optional parameter dictionary (on TPU, carries the
      TPUEstimator "context")
    config: an optional RunConfig (unused)
  Returns:
    a tpu_estimator.TPUEstimatorSpec (or a tf.estimator.EstimatorSpec when
    not running on TPU)
  """
  del labels, config
  use_tpu = FLAGS.tpu
  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  model = transformer.Bitransformer(
      encoder_layer_stack=layer_stack(include_encdec_attention=False),
      decoder_layer_stack=layer_stack(include_encdec_attention=True),
      encoder_d_model=FLAGS.d_model,
      decoder_d_model=FLAGS.d_model,
      input_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.inputs_vocab_size(FLAGS.dataset)),
      output_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.targets_vocab_size(FLAGS.dataset)),
      max_length=FLAGS.max_length,
      shared_embedding=False,
      shared_embedding_and_softmax_weights=True,
      label_smoothing=FLAGS.label_smoothing,
      layout=FLAGS.layout,
      mesh_shape=FLAGS.mesh_shape)

  inputs = import_feature(features, mesh, "inputs")

  # Data-types used for variables and activations
  # See comments in the FLAGS
  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  if FLAGS.slice_dtype:
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  elif not FLAGS.tpu or FLAGS.mode == "train":
    slice_dtype = tf.float32
  else:
    slice_dtype = tf.bfloat16
  if FLAGS.activation_dtype:
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)
  else:
    activation_dtype = tf.bfloat16 if FLAGS.tpu else tf.float32
  variable_dtype = mtf.VariableDType(master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     activation_dtype=activation_dtype)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    mtf_samples = model.decode(
        inputs,
        variable_dtype=variable_dtype,
        beam_size=FLAGS.beam_size,
        alpha=FLAGS.alpha,
        temperature=FLAGS.temperature)
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {
        "outputs": outputs
    }
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = import_feature(features, mesh, "targets")
  anon_targets = mtf.anonymize(targets)
  logits, loss = model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=variable_dtype,
      encoder_sequence_id=import_feature(features, mesh, "inputs_segmentation"),
      decoder_sequence_id=import_feature(
          features, mesh, "targets_segmentation"),
      encoder_position=import_feature(features, mesh, "inputs_position"),
      decoder_position=import_feature(features, mesh, "targets_position")
  )

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(
        tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      def padded_neg_log_perplexity(logits, labels):
        weights = tf.to_float(tf.not_equal(labels, 0))
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}
      labels = lowering.export_to_tf_tensor(anon_targets)
      eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)