コード例 #1
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_warp():
    """Check that `warp` chains view (reshape) transforms on a numpy array."""
    B, T, D = get_dim_vars('b t d')
    full_shape = (B, T, D)

    arr: 'btd' = np.ones(full_shape)

    # Two reshapes back to back: split d into (4, d//4), then merge b and t.
    merged = warp(arr, 'btd -> b,t,4,d//4 -> b*t,4,d//4', 'vv', debug=False)
    assert merged.shape == (B * T, 4, D // 4)

    # Four reshapes: the two above, followed by the steps that undo them.
    round_trip_spec = 'btd -> b,t,4,d//4 -> b*t,4,d//4 -> b,t,4,d//4 -> btd'
    restored = warp(arr, round_trip_spec, 'vvvv', debug=False)
    assert restored.shape == full_shape

    # The same four steps in shorthand, one string per transformation.
    shorthand_steps = [
        '__d -> ,,4,d//4',
        'b,t,, -> b*t,,',
        'b*t,, -> b,t,,',
        ',,4,d//4 -> ,,d',
    ]
    restored = warp(arr, shorthand_steps, 'vvvv', debug=True)
    assert restored.shape == full_shape

    print('test_warp: all assertions hold')
コード例 #2
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_reshape():
    """Split the last axis of a (B, T, D) array into (h, D // h) with numpy."""
    B, T, D = get_dim_vars('b t d')
    x: (B, T, D) = np.ones((B, T, D))

    h = 4
    split_shape = (B, T, h, D // h)
    x: (B, T, h, D // h) = x.reshape(split_shape)

    assert x.shape == split_shape
    print('test_reshape: all assertions hold')
コード例 #3
0
def test_foo():
    """Feed `foo` a zero-initialised TensorFlow variable of shape (b, t, d)."""
    from tsalib import get_dim_vars

    # Look up the already-declared dimension variables (the original note
    # gives their sizes as 10, 100, 1024).
    batch, length, depth = get_dim_vars('b t d')

    zeros = tf.zeros([batch, length, depth])
    foo(tf.Variable(zeros))
コード例 #4
0
ファイル: test_pytorch.py プロジェクト: victor8733/tsanley
def test_foo():
    """Feed `foo` an uninitialised torch tensor of shape (b, t, d)."""
    import torch
    from tsalib import get_dim_vars

    # Look up the already-declared dimension variables (the original note
    # gives their sizes as 10, 100, 1024).
    batch, length, depth = get_dim_vars('b t d')

    sample = torch.Tensor(batch, length, depth)
    foo(sample)
コード例 #5
0
ファイル: test_pytorch.py プロジェクト: victor8733/tsanley
def test_func():
    """Call `f1` and then `f2` on the same (b, t, d) torch tensor."""
    import torch
    from tsalib import get_dim_vars

    batch, length, depth = get_dim_vars('b t d')
    sample = torch.Tensor(batch, length, depth)

    # Per the original notes, f1 is the failing case and f2 the passing one.
    f1(sample)
    f2(sample)
コード例 #6
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_expand_short():
    """Check `et` when the expansion is given as the shorthand 'k->k*5'."""
    B, T, D, K = get_dim_vars('b t d k')
    src_dims = (B, K, T, D)

    base: 'btd' = np.ones((B, T, D))
    # Insert a new axis at position 1; it plays the role of k.
    lifted: 'bktd' = base[:, np.newaxis]

    result = et(src=src_dims, expansions='k->k*5', in_shape=lifted.shape)
    # Expected: 5 on the k axis and -1 on every other axis.
    assert result == (-1, 5, -1, -1)
    print('test_expand_short: all assertions hold')
コード例 #7
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def warp_long1():
    """Run a join -> permute -> view pipeline over three 'btd' arrays."""
    B, T, D, C = get_dim_vars('b t d c')

    # Three identically shaped 'btd' inputs for the leading join step.
    inputs = [np.ones((B, T, D)) for _ in range(3)]

    pipeline = '(btd)* -> btdc -> bdtc -> b,d//2,t*2,c'
    y = warp(inputs, pipeline, 'jpv')

    expected = (B, D // 2, T * 2, C)
    assert y.shape == expected
    print('warp_long1: all assertions hold')
コード例 #8
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_permute():
    """Check that `_pt` yields the axis order turning (B,T,D,K) into (D,T,B,K)."""
    B, T, D, K = get_dim_vars('b t d k')
    before = (B, T, D, K)
    after = (D, T, B, K)

    x: (B, T, D, K) = np.ones(before)

    axis_order = _pt(src=before, to=after)
    assert axis_order == (2, 1, 0, 3)

    # Applying the computed order must produce the target layout.
    x = x.transpose(axis_order)
    assert x.shape == after
    print('test_permute: all assertions hold')
コード例 #9
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_dot():
    """Contract the shared c axis of two random torch tensors with `dot`."""
    B, C, T, D = get_dim_vars('b c t d')

    lhs = torch.randn(B, C, T)
    rhs = torch.randn(C, D)

    # '_c_.c_' names only the contracted axis; the rest are placeholders.
    product = dot('_c_.c_', lhs, rhs)

    assert product.shape == (B, T, D)
    print('test_dot: all assertions passed')
コード例 #10
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_warp_pytorch():
    """Check `warp` on a torch tensor: one reshape followed by one permute."""
    B, T, D = get_dim_vars('b t d')

    import torch
    source: 'btd' = torch.randn(B, T, D)

    # 'v' splits d into (4, d//4); 'p' then swaps the t and 4 axes.
    spec = 'btd -> b,t,4,d//4 -> b,4,t,d//4'
    result = warp(source, spec, 'vp', debug=False)

    expected = (B, 4, T, D // 4)
    assert result.shape == expected

    print('test_warp_pytorch: all assertions hold')
コード例 #11
0
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
    """Looks up word embeddings for an id tensor.

  Args:
    input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
      ids. A rank-2 input gets a trailing axis of size 1 added.
    vocab_size: int. Size of the embedding vocabulary.
    embedding_size: int. Width of the word embeddings.
    initializer_range: float. Embedding initialization range.
    word_embedding_name: string. Name of the embedding table.
    use_one_hot_embeddings: bool. If True, use one-hot method for word
      embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better
      for TPUs.

  Returns:
    A tuple `(output, embedding_table)`:
      output: the looked-up embeddings after the final `warp` reshape, whose
        target layout is 'b,t,d*i' (i = size of the last axis of `input_ids`).
      embedding_table: the variable created with shape
        [vocab_size, embedding_size].
  """
    # This function assumes that the input is of shape [batch_size, seq_length,
    # num_inputs].
    #
    # If the input is a 2D tensor of shape [batch_size, seq_length], we
    # reshape to [batch_size, seq_length, 1].
    if input_ids.shape.ndims == 2:
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    # NOTE(review): B, T and D are not referenced again in this function.
    B, T, D = get_dim_vars('b t d')

    input_ids: 'bti'  # i: number of inputs (size of the last axis)
    # TODO: define/pick up i from input_ids
    i = get_shape_list(input_ids)[-1]

    # The embedding matrix: one row of width `embedding_size` per vocab entry.
    embedding_table: 'vd' = tf.get_variable(
        name=word_embedding_name,
        shape=[vocab_size, embedding_size],
        initializer=create_initializer(initializer_range))

    if use_one_hot_embeddings:
        # Flatten the ids, one-hot encode them and multiply by the table.
        flat_input_ids: 'b*t*i' = tf.reshape(input_ids, [-1])
        one_hot_input_ids: 'b*t*i,v' = tf.one_hot(flat_input_ids,
                                                  depth=vocab_size)
        output: 'b*t*i,d' = tf.matmul(one_hot_input_ids, embedding_table)
    else:
        # NOTE(review): the ids are not flattened on this path, so `output`
        # presumably is not 'b*t*i,d' here although the warp spec below names
        # that as its source shape -- confirm warp handles both layouts.
        output = tf.nn.embedding_lookup(embedding_table, input_ids)

    #input_shape: 'bti' = get_shape_list(input_ids)

    # Reshape to 'b,t,d*i'; `i` is interpolated into the spec as a literal.
    output: 'btd' = warp(output, tfms=f'b*t*{i},d -> b,t,d*{i}', tfm_names='r')

    return (output, embedding_table)
コード例 #12
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_expand():
    """Check `et` when the expansion is given as a list of (old, new) pairs."""
    B, T, D, K = get_dim_vars('b t d k')
    src_dims = (B, K, T, D)

    base: (B, T, D) = np.ones((B, T, D))
    # Insert a new axis at position 1; it plays the role of K.
    lifted: (B, K, T, D) = base[:, np.newaxis]

    # (B, K, T, D) -> (B, K*5, T, D)
    k_times_five = [(K, K * 5)]
    result = et(src=src_dims, expansions=k_times_five, in_shape=lifted.shape)

    # Expected: 5 on the K axis and -1 on every other axis.
    assert result == (-1, 5, -1, -1)
    print('test_expand: all assertions hold')
コード例 #13
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_numpy():
    """Show dim-var shape annotations alongside numpy zeros/stack/mean."""
    print('\nTest usage with numpy ..')
    B, D = get_dim_vars('b d')
    import numpy as np

    a: (B, D) = np.zeros((B, D))
    print(f'original array: {(B,D)}: {a.shape}')

    # Stacking two copies adds a leading axis of length 2.
    b: (2, B, D) = np.stack((a, a))
    print(f'after stack: {(2,B,D)}: {b.shape}')

    # Locate the B axis inside the declared shape and average it away.
    stacked_dims = (2, B, D)
    ax = stacked_dims.index(B)
    c: (2, D) = b.mean(axis=ax)
    print(f'after mean along axis = {ax}: {(2,D)}: {c.shape}')
コード例 #14
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_permute_short():
    """Check `pt` shorthand permutation strings against numpy transposes."""
    B, T, D, K, C, H, W = get_dim_vars('b t d k c h w')
    x: (B, T, D, K) = np.ones((B, T, D, K))

    # Fully named axes: (B, T, D, K) -> (D, T, B, K)
    swap_b_and_d = pt('btdk -> dtbk')
    x = x.transpose(swap_b_and_d)
    assert x.shape == (D, T, B, K)

    # '_' placeholders for untouched axes: (D, T, B, K) -> (B, T, D, K)
    swap_back = pt('d_b_ -> b_d_')
    x = x.transpose(swap_back)
    assert x.shape == (B, T, D, K)

    # Comma form with empty slots: channel axis moved to the end.
    x: (B, C, H, W) = np.ones((B, C, H, W))
    channels_last = pt(',c,, -> ,,,c')
    x1 = x.transpose(channels_last)
    assert x1.shape == (B, H, W, C)
    print('test_permute_short: all assertions hold')
コード例 #15
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_reshape_short():
    """Check `vt` shorthand view strings against numpy reshapes."""
    B, T, D = get_dim_vars('b t d')
    x: (B, T, D) = np.ones((B, T, D))
    h = 4

    # Split d into (h, d // h); h is interpolated into the spec as a literal.
    split_spec = f'btd -> b,t,{h},d//{h}'
    x = x.reshape(vt(split_spec, x.shape))
    assert x.shape == (B, T, h, D // h)

    # Merging b and t must give the same result whether the trailing axes are
    # spelled out or left as empty slots.
    merged_shape = (B * T, h, D // h)
    for merge_spec in ('b,t,4,k -> b*t,4,k', 'b,t,, -> b*t,,'):
        x1 = x.reshape(vt(merge_spec, x.shape))
        assert x1.shape == merged_shape

    print('test_reshape_short: all assertions hold')
コード例 #16
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_join_transform():
    """Check the `dims` string `join_transform` derives from a join spec."""
    B, T, D = get_dim_vars('b t d')
    arrays = [np.ones((B, T, D)) for _ in range(3)]

    # (join spec, expected dims string); the result would then be handed to a
    # backend-dependent join.
    cases = (
        ('(b,t,d)* -> b,3*t,d', ',*,'),
        ('(b,t,d)* -> b,^,t,d', ',^,,'),
    )
    for spec, expected in cases:
        dims = join_transform(arrays, spec)
        assert dims == expected

    print('test_join_transform: all assertions passed')
コード例 #17
0
ファイル: snippets_pytorch.py プロジェクト: vtmounica/tsalib
def tsa_attn(Y, ht, rt1):
    """Compute attention weights over `Y` and the resulting summary vector.

    Shapes follow the annotations: Y is 'bld', ht and rt1 are 'b,d'.
    Returns the pair (rt, at) annotated as 'bd' and 'bl' respectively.
    """
    B, L, D = get_dim_vars('b l d')
    Y: 'bld'
    ht: 'b,d'
    rt1: 'b,d'

    biases, weights = make_params(D)
    bM, br, w = biases  # each 'd,'
    WY, Wh, Wr, Wt = weights  # each 'd,d'

    # Project both 'bd' inputs, add them, then align the sum to Y's layout.
    hidden_term: 'bd' = dot('_d.d_', ht, Wh)
    previous_term: 'bd' = dot('_d.d_', rt1, Wr)
    tmp: 'bd' = hidden_term + previous_term
    tmpa: 'bld' = alignto((tmp, 'bd'), 'bld')

    Mt: 'bld' = torch.tanh(dot('__d.d_', Y, WY) + tmpa + bM)

    # Score every position, then normalise over the last axis.
    scores: 'bl' = dot('__d.d', Mt, w)
    at: 'bl' = F.softmax(scores, dim=-1)

    # Weighted sum of Y plus a tanh-transformed projection of rt1.
    weighted_sum: 'bd' = dot('bld,bl->bd', Y, at)
    carry: 'bd' = torch.tanh(dot('_d.d_', rt1, Wt) + br)
    rt: 'bd' = weighted_sum + carry

    return rt, at
コード例 #18
0
def test_decls():
    """Exercise declaring, printing and resizing dimension variables."""
    print('\nTest declarations ..')
    # _B, _C and _D are read from the enclosing scope.
    print(f'B, C, D = {_B}, {_C}, {_D}')

    # exists_ok=True is passed with the declaration; per the original note
    # this is what permits overwriting earlier declarations.
    declaration = 'Height(h):256 Width(w):256'
    H, W = dim_vars(declaration, exists_ok=True)
    print(f'H, W = {H}, {W}')

    # Update one dim var's length through the object itself ...
    H.update_len(1024)
    print(f'H = {H}')

    # ... then update several by shorthand name and fetch them again.
    update_dim_vars_len(dict(h=512, w=128))
    H, W = get_dim_vars('h w')
    print(f'H, W = {H}, {W}')
コード例 #19
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_pytorch():
    """Compare torch tensor sizes against shapes written with dim vars."""
    print('\nTest usage with pytorch ..')
    B, D = get_dim_vars('b d')
    # Re-declare both so they match the 2 x 3 literal below.
    B, D = dim_vars('Batch:2 EmbedDim:3', exists_ok=True)
    import torch

    rows = [[1., 2., 4.], [3., 6., 9.]]
    a = torch.Tensor(rows)
    assert a.shape == (B, D)

    pair = [a, a]

    # Stacking adds a new leading axis of length 2.
    b = torch.stack(pair)
    print('Asserting b.size() == (2,B,D)')
    assert b.shape == (2, B, D)

    # Concatenating along axis 1 doubles the D axis.
    c = torch.cat(pair, dim=1)
    print('Assertion on c.size()')
    assert c.shape == (B, D * 2)
コード例 #20
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_join():
    """Check `join` for concatenation and for stacking at two positions."""
    B, T, D = get_dim_vars('b t d')
    arrays = [np.ones((B, T, D)) for _ in range(3)]

    # (dims string, expected result shape)
    cases = (
        # concatenate along T: (b,t,d)* -> (b,3*t,d)
        (',*,', (B, 3 * T, D)),
        # stack on a new leading axis: (b,t,d)* -> (^,b,t,d)
        ('^', (3, B, T, D)),
        # stack on a new axis in second position: (b,t,d)* -> (b,^,t,d)
        (',^', (B, 3, T, D)),
    )
    for dims_spec, expected in cases:
        x = join(arrays, dims=dims_spec)
        assert x.shape == expected

    print('test_join: all assertions passed')
コード例 #21
0
ファイル: resnet.py プロジェクト: victor8733/tsanley
    def forward(self, x): # original note: H = W = 224
        """Run `x` through the network, annotating the expected shape per stage.

        Each string annotation is the shape shorthand for the value bound to
        `x` at that point. `x` is rebound at every step, so only the latest
        activation is referenced from this frame.
        """
        x: 'b,3,h,w'
        # Stem: conv1 -> bn1 -> relu -> maxpool.
        x: 'b,64,h//2,w//2' = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x: 'b,64,h//4,w//4' = self.maxpool(x)

        # Four stages; `e` in the annotations is the dim var fetched below
        # (presumably the block expansion factor -- TODO confirm).
        x: 'b,64*e,h//4,w//4' = self.layer1(x)
        x: 'b,128*e,h//8,w//8' = self.layer2(x)
        x: 'b,256*e,h//16,w//16' = self.layer3(x)
        x: 'b,512*e,h//32,w//32' = self.layer4(x)

        x: 'b,512*e,1,1' = self.avgpool(x)
        # Flatten to 2-D using the declared sizes of b and e.
        B, Ex = get_dim_vars('b e')
        x: 'b,512*e' = x.view(B, 512*Ex)
        x: 'b,nc' = self.fc(x)

        return x
コード例 #22
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_cast_int():
    """Build a numpy zeros array whose shape is given directly by dim vars."""
    print('\nTest integer cast ..')
    B, C = get_dim_vars('b c')

    # The dim vars themselves are handed to numpy as the shape.
    dims = (B, C)
    x = np.zeros(dims)

    print(f'shape of array: ({B},{C}): {x.shape}')
    return x
コード例 #23
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def test_arith():
    """Print a tuple mixing plain ints with arithmetic on dim vars."""
    print('\nTest arithmetic ..')
    # The fetched 'k' is discarded: the original immediately rebinds it.
    _, width, batch, height = get_dim_vars('k w b h')

    doubled_width = width * 2
    parts = 4

    summary = (parts, height // parts, doubled_width, batch * 2)
    print(summary)
コード例 #24
0
ファイル: test.py プロジェクト: AdityaGudimella/tsalib
def warp_long2():
    """Run an alignto -> permute -> view pipeline over one 'btd' array."""
    B, T, D, C = get_dim_vars('b t d c')
    source: 'btd' = np.ones((B, T, D))

    # 'a' appends a unit axis, 'p' swaps t and d, 'v' reshapes the middle axes.
    pipeline = 'btd -> btd1 -> bdt1 -> b,d//2,t*2,1'
    y = warp(source, pipeline, 'apv')

    expected = (B, D // 2, T * 2, 1)
    assert y.shape == expected
    print('warp_long2: all assertions hold')
コード例 #25
0
def embedding_postprocessor(input_tensor: 'btd',
                            use_token_type=False,
                            token_type_ids: 'bt' = None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
    """Performs various post-processing on a word embedding tensor.

    Optionally adds token-type and position embeddings to `input_tensor`,
    then applies `layer_norm_and_dropout`.

    Args:
      input_tensor: float Tensor of shape [batch_size, seq_length,
        embedding_size]. Its shape is checked against the declared dim vars
        b, t, d.
      use_token_type: bool. Whether to add embeddings for `token_type_ids`.
      token_type_ids: (optional) int32 Tensor of shape
        [batch_size, seq_length]. Must be specified if `use_token_type` is
        True.
      token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
      token_type_embedding_name: string. The name of the embedding table
        variable for token type ids.
      use_position_embeddings: bool. Whether to add position embeddings for
        the position of each token in the sequence.
      position_embedding_name: string. The name of the embedding table
        variable for positional embeddings.
      initializer_range: float. Range of the weight initialization.
      max_position_embeddings: int. Maximum sequence length that might ever
        be used with this model. This can be longer than the sequence length
        of input_tensor, but cannot be shorter.
      dropout_prob: float. Dropout probability applied to the final output
        tensor.

    Returns:
      float tensor with same shape as `input_tensor`.

    Raises:
      ValueError: One of the tensor shapes or input values is invalid.
    """
    # Sizes come from the declared dim vars rather than from the tensor; the
    # tensor's actual shape is then asserted to match them.
    B, T, D = int_shape(get_dim_vars('b t d'))
    seq_length, width = T, D
    size_assert(get_shape_list(input_tensor), (B, T, D))

    output: 'btd' = input_tensor

    if use_token_type:
        if token_type_ids is None:
            # The trailing space matters: the two literals are concatenated.
            raise ValueError("`token_type_ids` must be specified if "
                             "`use_token_type` is True.")
        token_type_table: 'tv,d' = tf.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, width],
            initializer=create_initializer(initializer_range))
        # This vocab will be small so we always do one-hot here, since it is
        # always faster for a small vocabulary.
        flat_token_type_ids = tf.reshape(token_type_ids, [-1])
        one_hot_ids = tf.one_hot(flat_token_type_ids,
                                 depth=token_type_vocab_size)
        token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
        token_type_embeddings: 'btd' = tf.reshape(token_type_embeddings,
                                                  (B, T, D))
        output += token_type_embeddings

    if use_position_embeddings:
        assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
        with tf.control_dependencies([assert_op]):
            full_position_embeddings = tf.get_variable(
                name=position_embedding_name,
                shape=[max_position_embeddings, width],
                initializer=create_initializer(initializer_range))
            # Since the position embedding table is a learned variable, we
            # create it using a (long) sequence length
            # `max_position_embeddings`. The actual sequence length might be
            # shorter than this, for faster training of tasks that do not
            # have long sequences.
            #
            # So `full_position_embeddings` is effectively an embedding table
            # for position [0, 1, 2, ..., max_position_embeddings-1], and the
            # current sequence has positions [0, 1, 2, ... seq_length-1], so
            # we can just perform a slice.
            position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                           [seq_length, -1])
            num_dims = len(output.shape.as_list())

            # Only the last two dimensions are relevant (`seq_length` and
            # `width`), so we broadcast among the first dimensions, which is
            # typically just the batch size.
            position_broadcast_shape = [1] * (num_dims - 2) + [T, D]
            position_embeddings = tf.reshape(position_embeddings,
                                             position_broadcast_shape)
            output += position_embeddings

    output: 'btd' = layer_norm_and_dropout(output, dropout_prob)
    return output
コード例 #26
0
ファイル: solver.py プロジェクト: firekind/style-transfer
from typing import List, Union, Callable

import torch
import torch.optim as optim
from PIL import Image
from pkbar import Kbar
from torchvision.models import vgg19
from tsalib import get_dim_vars

from .data import get_processing_transforms
from .losses import content_loss, style_loss
from .model import StyleTransferer
from .utils import ModelTargets

# Module-level dimension variables, looked up by name through tsalib.
B, C, H, W = get_dim_vars("B C H W")


def transfer_style(
    content_image_path: str = None,
    style_image_path: str = None,
    style_image: Image = None,
    content_image: Image = None,
    cuda: bool = True,
    image_size: int = 512,
    epochs: int = 16,
    content_weight: Union[float, List[float]] = 1,
    style_weight: Union[float, List[float]] = 1000000,
    extractor_mean: List[float] = (0.40760392, 0.45795686, 0.48501961),
    extractor_std: List[float] = (1, 1, 1),
    content_layers: List[str] = ("conv_4_2", ),
    style_layers: List[str] = ("conv_1_1", "conv_2_1", "conv_3_1", "conv_4_1",
コード例 #27
0
import tensorflow as tf
import tensorflow_addons as tfa
from tsalib import get_dim_vars

# Short aliases for the Keras / TensorFlow-Addons layer classes used below.
Layer = tf.keras.layers.Layer
GroupNormalization = tfa.layers.GroupNormalization
Activation = tf.keras.layers.Activation
Add = tf.keras.layers.Add
Conv3D = tf.keras.layers.Conv3D
UpSampling3D = tf.keras.layers.UpSampling3D
# Module-level dimension variables, looked up by name through tsalib.
B, C, H, W, D = get_dim_vars('B C H W D')


def BlueBlock(filters: int) -> Layer:
    """Return a stride-1, 3x3x3, 'same'-padded channels-first Conv3D layer.

    Args:
        filters (int): Number of output filters for the convolution.

    Returns:
        Layer: The configured `Conv3D` layer.
    """
    conv_options = dict(
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format="channels_first",
    )
    return Conv3D(filters=filters, **conv_options)


def DownSample(filters: int) -> Layer:
    """Return a stride-2, 3x3x3, 'same'-padded channels-first Conv3D layer.

    Args:
        filters (int): Number of output filters for the convolution.

    Returns:
        Layer: The configured `Conv3D` layer.
    """
    kernel = (3, 3, 3)
    stride = 2
    layer = Conv3D(filters=filters,
                   kernel_size=kernel,
                   strides=stride,
                   padding='same',
                   data_format="channels_first")
    return layer
コード例 #28
0
import tensorflow as tf
from tsalib import get_dim_vars

# Short alias for the Keras backend functions (sum, abs, square, mean).
K = tf.keras.backend
# Dimension variables used in the shape annotations of the function below.
B, C, H, W, D = get_dim_vars("B C H W D")


def dice_coefficient(y_true: (B, C, H, W, D),
                     y_pred: (B, C, H, W, D)) -> tf.float32:
    """
    Computes the dice coefficient metric for a batch of predictions.

    Args:
        y_true ((B, C, H, W, D)): The true values of the output
        y_pred ((B, C, H, W, D)): The predicted values of the model

    Returns:
        tf.float32: The dice coefficient
    """
    # Both sums reduce over the last three axes, leaving (B, C).
    reduce_axes = [-3, -2, -1]

    # Numerator: summed absolute element-wise product.
    # noinspection PyTypeChecker
    overlap: (B, C) = K.sum(K.abs(y_true * y_pred), axis=reduce_axes)

    # Denominator: summed squares of both inputs, offset by 1e-8.
    squares = K.square(y_true) + K.square(y_pred)
    total: (B, C) = K.sum(squares, axis=reduce_axes) + 1e-8

    # Per-(batch, channel) ratio, averaged over both remaining axes.
    ratio = 2 * overlap / total
    return K.mean(ratio, axis=[0, 1])
コード例 #29
0
ファイル: model.py プロジェクト: firekind/style-transfer
from collections import namedtuple
from typing import List, Tuple

import torch
import torch.nn as nn
from tsalib import get_dim_vars

# Dimension variables used in the shape annotations of the classes below.
B, C, H, W = get_dim_vars('B C H W')


class Normalization(nn.Module):
    """Subtracts a stored mean from the input and divides by a stored std."""

    def __init__(self, mean: (C, ), std: (C, )):
        """
        Store the statistics reshaped for broadcasting.

        Args:
            mean (torch.Tensor): The mean. shape: (C)
            std (torch.Tensor): The standard deviation. shape (C)
        """
        super().__init__()

        # View both as (C, 1, 1) so they broadcast against (B, C, H, W).
        broadcast_shape = (-1, 1, 1)
        self.mean: (C, 1, 1) = mean.view(*broadcast_shape)
        self.std: (C, 1, 1) = std.view(*broadcast_shape)

    def forward(self, x: (B, C, H, W)) -> (B, C, H, W):
        """Return `x` shifted by the stored mean and scaled by the stored std."""
        centered = x - self.mean
        return centered / self.std


class StyleTransferer(nn.Module):
    def __init__(self, content_layers: List[str], style_layers: List[str],
コード例 #30
0
    def __init__(self,
                 config,
                 is_training,
                 input_ids: 'bt',
                 input_mask: 'bt' = None,
                 token_type_ids: 'bt' = None,
                 use_one_hot_embeddings=True,
                 scope=None):
        """Constructor for BertModel.

    Builds the embedding, encoder and pooler sub-graphs and stores their
    outputs on `self` as `embedding_output`, `embedding_table`,
    `all_encoder_layers`, `sequence_output` and `pooled_output`.

    Args:
      config: `BertConfig` instance. A deep copy is made, so the caller's
        object is not modified.
      is_training: bool. true for training model, false for eval model. Controls
        whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
        Defaults to all ones.
      token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
        Defaults to all zeros.
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings. On the TPU,
        it is much faster if this is True, on the CPU or GPU, it is faster if
        this is False.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
        # Work on a private copy so the dropout overrides below stay local.
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        #input_shape = get_shape_list(input_ids, expected_rank=2)
        #batch_size = input_shape[0]
        #seq_length = input_shape[1]
        # Sizes are taken from the declared dim vars, not from `input_ids`.
        # NOTE(review): N is not referenced again in this method.
        B, T, D, N = get_dim_vars('b t d n')

        if input_mask is None:
            input_mask = tf.ones(shape=[B, T], dtype=tf.int32)

        if token_type_ids is None:
            token_type_ids = tf.zeros(shape=[B, T], dtype=tf.int32)

        with tf.variable_scope(scope, default_name="bert"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                self.embedding_output: 'btd'
                self.embedding_table: 'vd'
                (self.embedding_output,
                 self.embedding_table) = embedding_lookup(
                     input_ids=input_ids,
                     vocab_size=config.vocab_size,
                     embedding_size=config.hidden_size,
                     initializer_range=config.initializer_range,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=use_one_hot_embeddings)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.embedding_output = embedding_postprocessor(
                    input_tensor=self.embedding_output,
                    use_token_type=True,
                    token_type_ids=token_type_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name="token_type_embeddings",
                    use_position_embeddings=True,
                    position_embedding_name="position_embeddings",
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob)

            with tf.variable_scope("encoder"):
                # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
                # mask of shape [batch_size, seq_length, seq_length] which is used
                # for the attention scores.
                attention_mask: 'btt' = create_attention_mask_from_input_mask(
                    input_ids, input_mask)

                # Run the stacked transformer.
                # `sequence_output` shape = [batch_size, seq_length, hidden_size].
                self.all_encoder_layers: '(btd)*' = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=attention_mask,
                    hidden_size=config.hidden_size,
                    num_hidden_layers=config.num_hidden_layers,
                    num_attention_heads=config.num_attention_heads,
                    intermediate_size=config.intermediate_size,
                    intermediate_act_fn=get_activation(config.hidden_act),
                    hidden_dropout_prob=config.hidden_dropout_prob,
                    attention_probs_dropout_prob=config.
                    attention_probs_dropout_prob,
                    initializer_range=config.initializer_range,
                    do_return_all_layers=True)

            # The last element of the per-layer outputs is the final encoding.
            self.sequence_output: 'btd' = self.all_encoder_layers[-1]
            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_size, seq_length, hidden_size] to a tensor of shape
            # [batch_size, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained
                #TODO: select_squeeze
                first_token_tensor: 'bd' = tf.squeeze(
                    self.sequence_output[:, 0:1, :], axis=1)
                # NOTE(review): the dense width is the dim var D rather than
                # config.hidden_size -- presumably the two agree; confirm.
                self.pooled_output: 'bd' = tf.layers.dense(
                    first_token_tensor,
                    D,
                    activation=tf.tanh,
                    kernel_initializer=create_initializer(
                        config.initializer_range))