Exemple #1
0
def test_entmax():
    """Smoke-test entmax: build forward, gradient and sampling ops, then lower them.

    A float32 range tensor along an 8-element dimension is passed through
    ``entmax``. The gradient of that output with respect to the input, and a
    ``sample_categorical`` draw from the output, are both exported as TF
    tensors. No values are asserted and nothing is evaluated in a session
    here; the test only exercises graph construction and lowering.
    """
    mtf_graph = mtf.Graph()
    mtf_mesh = mtf.Mesh(mtf_graph, "my_mesh")
    length_dim = mtf.Dimension("tensor_length", 8)

    # Ops are created in the order: forward -> gradient -> categorical sample.
    values = mtf.range(mtf_mesh, length_dim, tf.float32)
    distribution = entmax(values)
    gradients = mtf.gradients([distribution], [values])
    input_grad = gradients[0]
    sampled = sample_categorical(distribution, length_dim)

    placement = placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowered = mtf.Lowering(mtf_graph, {mtf_mesh: placement})

    # Export the sample first, then the gradient.
    for mtf_tensor in (sampled, input_grad):
        lowered.export_to_tf_tensor(mtf_tensor)
Exemple #2
0
def test_model():
    """Smoke-test ``gpt2.model``: build it on dummy inputs and lower the logits.

    Inputs and labels are all-ones int32 tensors of shape
    (batch=1, sequence=params["n_ctx"]). The model is built and its logits are
    exported to a TF tensor inside ``not_raises(Exception)``. That helper is
    defined elsewhere; presumably it fails the test if anything in its body
    raises -- confirm against its definition.
    """
    mtf_graph = mtf.Graph()
    mtf_mesh = mtf.Mesh(mtf_graph, "my_mesh")

    n_ctx = params["n_ctx"]
    n_mem_kv = params.get('num_mem_kv', 0)

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", n_ctx)
    token_shape = mtf.Shape((batch_dim, sequence_dim))

    # Dummy token ids: inputs and labels are separate all-ones tensors.
    features = {}
    for feature_name in ('inputs', 'labels'):
        features[feature_name] = mtf.ones(mtf_mesh, token_shape, tf.int32)

    # Dimensions consumed by the attention mask and handed to the model.
    length_dim = mtf.Dimension('sequence', n_ctx)
    memory_length_dim = mtf.Dimension('memory_length', n_ctx + n_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', n_ctx)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    other_features = {
        "attn_bias": biasmask_attn_weights(
            mtf_mesh, length_dim, memory_length_dim, variable_dtype),
        "embd_dim": embd_dim,
        "vocab_dim": vocab_dim,
        "embed_sequence_dim": embed_sequence_dim,
        "memory_length_dim": memory_length_dim,
    }

    with not_raises(Exception):
        model_outputs = gpt2.model(
            features,
            other_features,
            params,
            mtf_mesh,
            variable_dtype=variable_dtype)
        # The model returns three values; only the first (logits) is lowered.
        logits, _, _ = model_outputs

        placement = placement_mesh_impl.PlacementMeshImpl(
            shape=[], layout={}, devices=[""])
        lowered = mtf.Lowering(mtf_graph, {mtf_mesh: placement})
        logits = lowered.export_to_tf_tensor(logits)
Exemple #3
0
def test_sampling():
    """Smoke-test ``sample_autoregressive`` with entmax sampling enabled.

    A (batch=1, sequence=1) all-ones int32 tensor is padded along the sequence
    dimension and passed to ``sample_autoregressive`` in "predict" mode. The
    resulting samples are exported to a TF tensor inside
    ``not_raises(Exception)``. That helper is defined elsewhere; presumably it
    fails the test if anything in its body raises -- confirm against its
    definition.

    The shared ``params`` mapping is not modified: the "predict" mode override
    is applied to a shallow copy, so tests that run afterwards do not inherit
    ``mode == "predict"``.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", 1)

    inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    # Pads the sequence dimension with 0 before / 3 after (1 -> 4 elements).
    # NOTE(review): presumably 4 is meant to equal params["n_ctx"] -- confirm.
    inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)

    # create mask

    seq_len = params["n_ctx"]
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}

    other_features["attn_bias"] = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    # Work on a shallow copy so the mode override stays local to this test.
    # A distinct local name is required: rebinding `params` here would make it
    # a function-local and break the reads above.
    predict_params = params.copy()
    predict_params["mode"] = "predict"

    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs,
            other_features=other_features,
            params=predict_params,
            variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=predict_params["remove_partial_sequences"],
            stop_at_token=predict_params["eos_id"],
            sampling_use_entmax=True)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)