Example #1
import numpy as np

import nnabla as nn
import nnabla.functions as F


def vision_transformer(x, input_res, patch_size, v_width, v_layers, v_heads,
                       embed_dim):
    # Initialization scale kept from the original CLIP code; unused in this
    # inference-only path, where parameters are loaded from a file.
    scale = v_width**-0.5

    with nn.parameter_scope("visual"):
        con1_w = nn.parameter.get_parameter_or_create(name="conv1/W",
                                                      shape=(v_width, 3,
                                                             patch_size,
                                                             patch_size))
        x = F.convolution(
            x, con1_w, bias=None,
            stride=(patch_size, patch_size))  # shape = [*, width, grid, grid]

        # shape = [*, width, grid ** 2]
        x = F.reshape(x, (x.shape[0], x.shape[1], -1))
        x = F.transpose(x, (0, 2, 1))  # shape = [*, grid ** 2, width]

        # Adding a zero tensor with the batch shape broadcasts the single
        # learned class token across the batch before concatenation.
        z = np.zeros((x.shape[0], 1, x.shape[-1]), dtype=np.float32)
        zeros = nn.Variable.from_numpy_array(z)
        class_embed = nn.parameter.get_parameter_or_create(
            name="class_embedding", shape=(v_width, )).reshape(
                (1, 1, v_width))
        # shape = [*, grid ** 2 + 1, width]
        x = F.concatenate(class_embed + zeros, x, axis=1)

        # One positional embedding per patch plus one for the class token,
        # broadcast over the batch axis.
        positional_embedding = nn.parameter.get_parameter_or_create(
            name='positional_embedding',
            shape=((input_res // patch_size)**2 + 1, v_width)).reshape(
                (1, x.shape[1], v_width))
        x = x + positional_embedding

        # Pre-transformer LayerNorm over the channel axis only
        # (batch_axis=(0, 1) keeps the batch and token axes as "batch").
        ln_pre_w = nn.parameter.get_parameter_or_create(
            name="ln_pre/W", shape=(v_width, )).reshape((1, 1, v_width))
        ln_pre_b = nn.parameter.get_parameter_or_create(
            name="ln_pre/b", shape=(v_width, )).reshape((1, 1, v_width))
        x = F.layer_normalization(x, ln_pre_b, ln_pre_w, batch_axis=(0, 1))

        x = F.transpose(x, (1, 0, 2))  # NLD -> LND

        x = transformer(x, v_width, v_layers, v_heads)

        x = F.transpose(x, (1, 0, 2))  # LND -> NLD

        ln_post_w = nn.parameter.get_parameter_or_create(
            name="ln_post/W", shape=(v_width, )).reshape((1, 1, v_width))
        ln_post_b = nn.parameter.get_parameter_or_create(
            name="ln_post/b", shape=(v_width, )).reshape((1, 1, v_width))
        # Keep only the class-token output before the final LayerNorm.
        x = F.slice(x, stop=(x.shape[0], 1, x.shape[2]))
        x = F.layer_normalization(x, ln_post_b, ln_post_w)

        # Project to the joint image-text embedding space only when the
        # 'proj' parameter exists (e.g. it was loaded from a pretrained file).
        if 'proj' in nn.get_parameters():
            visual_proj = nn.parameter.get_parameter_or_create(
                name="proj", shape=(v_width, embed_dim)).reshape(
                    (1, v_width, -1))
            x = F.batch_matmul(x, visual_proj)

        x = x.reshape((-1, embed_dim))

    return x
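A minimal smoke-test sketch for this function. It assumes the `transformer` helper from the same module is available, and it pre-creates the 'proj' parameter because the projection branch only runs when that parameter already exists; in real use the weights would come from a pretrained parameter file, so the shapes and values below are illustrative only:

import numpy as np
import nnabla as nn

# Hypothetical setup: ViT-B/32-like shapes, batch size 1, random weights.
with nn.parameter_scope("visual"):
    nn.parameter.get_parameter_or_create(name="proj", shape=(768, 512))

img = nn.Variable.from_numpy_array(
    np.random.randn(1, 3, 224, 224).astype(np.float32))
feat = vision_transformer(img, input_res=224, patch_size=32, v_width=768,
                          v_layers=12, v_heads=12, embed_dim=512)
feat.forward()
print(feat.shape)  # (1, 512)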
Example #2
import numpy as np

import nnabla as nn
import nnabla.functions as F


def test_layer_normalization_forward_backward(seed, x_shape, batch_axis, output_stat):
    # seed, x_shape, batch_axis and output_stat are supplied by pytest
    # parametrization; _force_list and ref_layer_normalization are helpers
    # defined in the same test module.
    rng = np.random.RandomState(seed)
    input = rng.randn(*x_shape).astype(np.float32)

    # Batch axes keep their size in the output statistics; the normalized
    # axes collapse to 1.
    stat_shape = tuple([x_shape[i] if i in _force_list(batch_axis) else 1
                        for i in range(len(x_shape))])

    beta = rng.randn(*stat_shape).astype(np.float32)
    gamma = rng.randn(*stat_shape).astype(np.float32)
    eps = 1e-05

    x = nn.Variable.from_numpy_array(input)
    v_beta = nn.Variable.from_numpy_array(beta)
    v_gamma = nn.Variable.from_numpy_array(gamma)

    output = F.layer_normalization(
        x, v_beta, v_gamma, batch_axis, eps, output_stat)
    ref = ref_layer_normalization(
        input, beta, gamma, batch_axis, eps, output_stat)

    if output_stat:
        tmp = F.sink(*output)
        tmp.forward()
        tmp.backward()

        for o, r in zip(output, ref):
            assert o.shape == r.shape
            assert np.allclose(o.d, r, atol=1e-2, rtol=1e-5)

    else:
        output.forward()
        output.backward()

        assert np.allclose(output.d, ref, atol=1e-2, rtol=1e-5)
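The reference implementation is not shown on this page; the following is a minimal NumPy sketch of what it plausibly computes, assuming batch_axis is already a tuple. Layer normalization takes mean and variance over every non-batch axis:

import numpy as np

def ref_layer_normalization(x, beta, gamma, batch_axis, eps, output_stat):
    # Normalize over every axis that is not a batch axis.
    norm_axes = tuple(i for i in range(x.ndim) if i not in batch_axis)
    mean = x.mean(axis=norm_axes, keepdims=True)
    var = x.var(axis=norm_axes, keepdims=True)
    y = (x - mean) / np.sqrt(var + eps) * gamma + beta
    return (y, mean, var) if output_stat else y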
Example #3
import nnabla as nn
import nnabla.functions as F


def layernorm(x, i, d_model):
    # LayerNorm over the trailing d_model axis; batch_axis=(0, 1) treats the
    # first two axes (e.g. sequence and batch) as batch dimensions, and i
    # distinguishes the per-layer parameters ln_{i}/W and ln_{i}/b.
    weight = nn.parameter.get_parameter_or_create(name=f"ln_{i}/W",
                                                  shape=(d_model, )).reshape(
                                                      (1, 1, d_model))
    bias = nn.parameter.get_parameter_or_create(name=f"ln_{i}/b",
                                                shape=(d_model, )).reshape(
                                                    (1, 1, d_model))

    return F.layer_normalization(x, bias, weight, batch_axis=(0, 1))
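A short usage sketch with illustrative (L, N, d_model) shapes:

import numpy as np
import nnabla as nn

h = nn.Variable.from_numpy_array(
    np.random.randn(77, 4, 512).astype(np.float32))  # (L, N, d_model)
h = layernorm(h, i=0, d_model=512)  # creates or reuses ln_0/W and ln_0/b
h.forward()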
Example #4
import nnabla as nn
import nnabla.functions as F


def encode_text(text):
    # transformer and build_attn_mask are helpers from the same module.
    # Every hyperparameter is inferred from the shapes of already-loaded
    # parameters, so pretrained weights must be loaded before calling this.
    param_dict = nn.get_parameters()

    embed_dim = param_dict['text_projection'].shape[1]
    context_length = param_dict['positional_embedding'].shape[0]
    vocab_size = param_dict['token_embedding/W'].shape[0]
    transformer_width = param_dict['ln_final/W'].shape[0]
    transformer_heads = transformer_width // 64
    # Count the distinct block indices under transformer/resblocks/<idx>/...
    transformer_layers = len(
        set(
            k.split('/')[2] for k in param_dict.keys()
            if k.startswith('transformer/resblocks')))

    token_embedding = nn.parameter.get_parameter_or_create(
        name='token_embedding/W', shape=(vocab_size, transformer_width))
    x = F.embed(text, token_embedding)  # [batch_size, n_ctx, d_model]

    positional_embedding = nn.parameter.get_parameter_or_create(
        name='positional_embedding',
        shape=(context_length, transformer_width)).reshape(
            (1, context_length, transformer_width))
    x = x + positional_embedding

    x = F.transpose(x, (1, 0, 2))  # NLD -> LND

    x = transformer(x,
                    transformer_width,
                    transformer_layers,
                    transformer_heads,
                    attn_mask=build_attn_mask(context_length))

    x = F.transpose(x, (1, 0, 2))  # LND -> NLD

    ln_final_W = nn.parameter.get_parameter_or_create(
        name='ln_final/W', shape=(transformer_width, )).reshape(
            (1, 1, transformer_width))
    ln_final_b = nn.parameter.get_parameter_or_create(
        name='ln_final/b', shape=(transformer_width, )).reshape(
            (1, 1, transformer_width))
    x = F.layer_normalization(x, ln_final_b, ln_final_W, batch_axis=(0, 1))

    # In CLIP the end-of-text token has the highest token id, so the argmax
    # over token ids locates the EOT position in each sequence.
    idx = F.max(text, axis=-1, only_index=True)
    idx.forward()
    x = x[list(range(x.shape[0])), idx.d].reshape((1, x.shape[0], -1))
    text_projection = nn.parameter.get_parameter_or_create(
        name='text_projection', shape=(transformer_width, embed_dim)).reshape(
            (1, transformer_width, embed_dim))
    x = F.batch_matmul(x, text_projection)

    x = x.reshape((-1, embed_dim))

    return x
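Since encode_text derives everything from loaded parameters, a call only makes sense after nn.load_parameters. A hedged sketch, where the parameter file name and the tokenize helper are placeholders rather than part of this code:

import nnabla as nn

nn.load_parameters("clip_weights.h5")  # placeholder file name
# tokenize() stands in for a CLIP-style tokenizer returning an int array
# of shape (batch_size, context_length).
tokens = nn.Variable.from_numpy_array(tokenize(["a photo of a cat"]))
features = encode_text(tokens)
features.forward()
print(features.shape)  # (batch_size, embed_dim)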