import numpy as np

import nnabla as nn
import nnabla.functions as F


def vision_transformer(x, input_res, patch_size, v_width, v_layers, v_heads,
                       embed_dim):
    scale = v_width ** -0.5  # NOTE: unused below; kept from the reference code

    with nn.parameter_scope("visual"):
        con1_w = nn.parameter.get_parameter_or_create(
            name="conv1/W", shape=(v_width, 3, patch_size, patch_size))
        x = F.convolution(
            x, con1_w, bias=None,
            stride=(patch_size, patch_size))  # shape = [*, width, grid, grid]

        # shape = [*, width, grid ** 2]
        x = F.reshape(x, (x.shape[0], x.shape[1], -1))
        x = F.transpose(x, (0, 2, 1))  # shape = [*, grid ** 2, width]

        z = np.zeros((x.shape[0], 1, x.shape[-1]))
        zeros = nn.Variable.from_numpy_array(z)
        class_embed = nn.parameter.get_parameter_or_create(
            name="class_embedding",
            shape=(v_width, )).reshape((x.shape[0], 1, v_width))
        # shape = [*, grid ** 2 + 1, width]
        x = F.concatenate(class_embed + zeros, x, axis=1)

        positional_embedding = nn.parameter.get_parameter_or_create(
            name='positional_embedding',
            shape=((input_res // patch_size) ** 2 + 1, v_width)).reshape(
                (x.shape[0], x.shape[1], v_width))
        x = x + positional_embedding

        ln_pre_w = nn.parameter.get_parameter_or_create(
            name="ln_pre/W", shape=(v_width, )).reshape((1, 1, v_width))
        ln_pre_b = nn.parameter.get_parameter_or_create(
            name="ln_pre/b", shape=(v_width, )).reshape((1, 1, v_width))
        x = F.layer_normalization(x, ln_pre_b, ln_pre_w, batch_axis=(0, 1))

        x = F.transpose(x, (1, 0, 2))  # NLD -> LND
        x = transformer(x, v_width, v_layers, v_heads)
        x = F.transpose(x, (1, 0, 2))  # LND -> NLD

        ln_post_w = nn.parameter.get_parameter_or_create(
            name="ln_post/W", shape=(v_width, )).reshape((1, 1, v_width))
        ln_post_b = nn.parameter.get_parameter_or_create(
            name="ln_post/b", shape=(v_width, )).reshape((1, 1, v_width))
        # Keep only the class-token position before the final LayerNorm.
        x = F.slice(x, stop=(x.shape[0], 1, x.shape[2]))
        x = F.layer_normalization(x, ln_post_b, ln_post_w)

        if 'proj' in nn.get_parameters():
            visual_proj = nn.parameter.get_parameter_or_create(
                name="proj", shape=(v_width, embed_dim)).reshape(
                    (1, v_width, -1))
            x = F.batch_matmul(x, visual_proj)

    x = x.reshape((-1, embed_dim))
    return x
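# --- Usage sketch (not part of the implementation above) ---------------------
# A minimal, hedged example of calling vision_transformer, assuming the
# `transformer` helper defined elsewhere in this file and ViT-B/32-like
# hyperparameters (input_res=224, patch_size=32, v_width=768, v_layers=12,
# v_heads=12, embed_dim=512); the values are illustrative, not prescribed by
# the function itself. Batch size 1, because the class/positional embedding
# reshapes above bake the batch dimension into the parameter shape.
img = nn.Variable.from_numpy_array(
    np.random.randn(1, 3, 224, 224).astype(np.float32))  # [B, C, H, W]
feat = vision_transformer(img, input_res=224, patch_size=32, v_width=768,
                          v_layers=12, v_heads=12, embed_dim=512)
feat.forward()
print(feat.shape)  # (1, 512): one embedding per image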
def test_layer_normalization_forward_backward(seed, x_shape, batch_axis,
                                              output_stat):
    rng = np.random.RandomState(seed)
    input = rng.randn(*x_shape).astype(np.float32)

    # beta/gamma share the shape of the per-sample statistics:
    # batch axes keep their extent, every other axis is reduced to 1.
    stat_shape = tuple([x_shape[i] if i in _force_list(batch_axis) else 1
                        for i in range(len(x_shape))])

    beta = rng.randn(*stat_shape).astype(np.float32)
    gamma = rng.randn(*stat_shape).astype(np.float32)
    eps = 1e-05

    x = nn.Variable.from_numpy_array(input)
    v_beta = nn.Variable.from_numpy_array(beta)
    v_gamma = nn.Variable.from_numpy_array(gamma)

    output = F.layer_normalization(
        x, v_beta, v_gamma, batch_axis, eps, output_stat)
    ref = ref_layer_normalization(
        input, beta, gamma, batch_axis, eps, output_stat)

    if output_stat:
        tmp = F.sink(*output)
        tmp.forward()
        tmp.backward()

        for o, r in zip(output, ref):
            assert o.shape == r.shape
            assert np.allclose(o.d, r, atol=1e-2, rtol=1e-5)
    else:
        output.forward()
        output.backward()

        assert np.allclose(output.d, ref, atol=1e-2, rtol=1e-5)
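# --- Reference sketch (assumption, not the test suite's actual helpers) ------
# The test above uses `_force_list` and `ref_layer_normalization` without
# showing them. A minimal NumPy sketch of what they could look like, assuming
# nnabla's convention that layer normalization reduces over every non-batch
# axis and, with output_stat=True, also returns the mean and variance.
def _force_list(axes):
    # Hypothetical helper: lets batch_axis be a single int or a sequence.
    return list(axes) if hasattr(axes, '__iter__') else [axes]


def ref_layer_normalization(x, beta, gamma, batch_axis, eps, output_stat):
    batch_axis = _force_list(batch_axis)
    # Reduce over every axis that is NOT a batch axis.
    reduce_axes = tuple(i for i in range(x.ndim) if i not in batch_axis)
    mean = x.mean(axis=reduce_axes, keepdims=True)
    var = x.var(axis=reduce_axes, keepdims=True)
    out = (x - mean) / np.sqrt(var + eps) * gamma + beta
    if output_stat:
        return out, mean, var
    return out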
def layernorm(x, i, d_model):
    weight = nn.parameter.get_parameter_or_create(
        name=f"ln_{i}/W", shape=(d_model, )).reshape((1, 1, d_model))
    bias = nn.parameter.get_parameter_or_create(
        name=f"ln_{i}/b", shape=(d_model, )).reshape((1, 1, d_model))

    return F.layer_normalization(x, bias, weight, batch_axis=(0, 1))
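# --- Usage sketch (illustrative values) ---------------------------------------
# layernorm normalizes each (batch, position) slice across the feature axis,
# since batch_axis=(0, 1) leaves only axis 2 to reduce over. The input shape
# and d_model=512 are assumptions for illustration; `i` keys the parameter
# names, so distinct layers (ln_0/*, ln_1/*, ...) get independent weights.
h = nn.Variable.from_numpy_array(
    np.random.randn(4, 77, 512).astype(np.float32))  # (batch, length, d_model)
h = layernorm(h, i=0, d_model=512)  # creates/reuses ln_0/W and ln_0/b
h.forward()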
def encode_text(text):
    param_dict = nn.get_parameters()

    # Infer the architecture from the shapes of the loaded parameters.
    embed_dim = param_dict['text_projection'].shape[1]
    context_length = param_dict['positional_embedding'].shape[0]
    vocab_size = param_dict['token_embedding/W'].shape[0]
    transformer_width = param_dict['ln_final/W'].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        set(k.split('/')[2] for k in param_dict.keys()
            if k.startswith('transformer/resblocks')))

    token_embedding = nn.parameter.get_parameter_or_create(
        name='token_embedding/W', shape=(vocab_size, transformer_width))
    x = F.embed(text, token_embedding)  # [batch_size, n_ctx, d_model]

    positional_embedding = nn.parameter.get_parameter_or_create(
        name='positional_embedding',
        shape=(context_length, transformer_width)).reshape(
            (1, context_length, transformer_width))
    x = x + positional_embedding

    x = F.transpose(x, (1, 0, 2))  # NLD -> LND
    x = transformer(x, transformer_width, transformer_layers,
                    transformer_heads,
                    attn_mask=build_attn_mask(context_length))
    x = F.transpose(x, (1, 0, 2))  # LND -> NLD

    ln_final_W = nn.parameter.get_parameter_or_create(
        name='ln_final/W', shape=(transformer_width, )).reshape(
            (1, 1, transformer_width))
    ln_final_b = nn.parameter.get_parameter_or_create(
        name='ln_final/b', shape=(transformer_width, )).reshape(
            (1, 1, transformer_width))
    x = F.layer_normalization(x, ln_final_b, ln_final_W, batch_axis=(0, 1))

    # Take the features at the EOT token, i.e. the position holding the
    # highest token id in each sequence.
    idx = F.max(text, axis=-1, only_index=True)
    idx.forward()
    x = x[list(range(x.shape[0])), idx.d].reshape((1, x.shape[0], -1))

    text_projection = nn.parameter.get_parameter_or_create(
        name='text_projection',
        shape=(transformer_width, embed_dim)).reshape(
            (1, transformer_width, embed_dim))
    x = F.batch_matmul(x, text_projection)

    x = x.reshape((-1, embed_dim))
    return x
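# --- Usage sketch (assumptions flagged inline) --------------------------------
# encode_text reads shapes out of nn.get_parameters(), so pretrained CLIP
# weights must be loaded first. The parameter file name below and the
# `tokenize` helper (anything yielding integer token ids of shape
# (batch, context_length) as an nn.Variable) are assumptions, not defined
# in this file.
nn.load_parameters('ViT-B-32.h5')        # assumed pretrained parameter file
tokens = tokenize(["a photo of a cat"])  # hypothetical tokenizer -> (1, 77)
text_feat = encode_text(tokens)
text_feat.forward()
print(text_feat.shape)  # (1, embed_dim)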