Example 1
 def __call__(self, x, is_training):
     out = self.activation(x) * self.beta
     if self.stride > 1:  # Average-pool downsample. This is the transition block; when does self.stride > 1 apply?
         shortcut = hk.avg_pool(out,
                                window_shape=(1, 2, 2, 1),
                                strides=(1, 2, 2, 1),
                                padding='SAME')
         if self.use_projection:
             shortcut = self.conv_shortcut(shortcut)
     elif self.use_projection:
         shortcut = self.conv_shortcut(out)
     else:
         shortcut = x
     out = self.conv0(out)  # 1x1
     out = self.conv1(self.activation(out))  # 3x3
     if self.use_two_convs:
         out = self.conv1b(self.activation(out))  # 3x3
     out = self.conv2(self.activation(out))  # 1x1
     out = (
         self.se(out) * 2
     ) * out  # Multiply by 2 for rescaling. Didn't some paper also claim this works better when doubled..?
     # Get average residual standard deviation for reporting metrics.
     res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
     # Apply stochdepth if applicable.
     if self._has_stochdepth:
         out = self.stoch_depth(out, is_training)
     # SkipInit Gain
     out = out * hk.get_parameter(
         'skip_gain', (), out.dtype, init=jnp.zeros)
     return out * self.alpha + shortcut, res_avg_var
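On the question in the comment above: in ResNet-style stacks the stride is usually greater than 1 only in the first (transition) block of each stage, and that is when the average-pooled shortcut path runs. Below is a minimal, self-contained sketch of just that shortcut path; the function name, channel count, and input shape are assumptions for illustration, not taken from the snippet.

import haiku as hk
import jax
import jax.numpy as jnp

def shortcut_fn(x, stride=2, use_projection=True):
    out = x
    if stride > 1:
        # Downsample the shortcut with a stride-2 average pool (NHWC layout assumed).
        out = hk.avg_pool(out,
                          window_shape=(1, 2, 2, 1),
                          strides=(1, 2, 2, 1),
                          padding='SAME')
    if use_projection:
        # Hypothetical 1x1 projection to match the block's output channels.
        out = hk.Conv2D(256, kernel_shape=1, name='conv_shortcut')(out)
    return out

forward = hk.without_apply_rng(hk.transform(shortcut_fn))
x = jnp.ones((8, 32, 32, 64))
params = forward.init(jax.random.PRNGKey(0), x)
y = forward.apply(params, x)
print(y.shape)  # (8, 16, 16, 256): spatial dims halved by the pooled shortcut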
Example 2
 def __call__(self, x, is_training):
     out = jax.nn.relu(self.bn1(self.conv1(x), is_training=is_training))
     out = self.layer1(out, is_training=is_training)
     out = self.layer2(out, is_training=is_training)
     out = self.layer3(out, is_training=is_training)
     out = hk.avg_pool(out,
                       window_shape=(1, 1, out.shape[2], out.shape[3]),
                       strides=(1, 1, out.shape[2], out.shape[3]),
                       padding='SAME')
     out = out.reshape(out.shape[0], -1)
     out = self.linear(out)
     return out
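The window in Example 2 spans the full extent of axes 2 and 3, so this call acts as a global average pool over those dimensions before the classifier head. A minimal sketch showing the equivalence with a plain mean; the NCHW-style toy input is an assumption, since the snippet does not show how its convolutions are configured.

import haiku as hk
import jax.numpy as jnp

out = jnp.arange(2 * 4 * 8 * 8, dtype=jnp.float32).reshape(2, 4, 8, 8)

# Window and strides cover the full extent of the last two axes, as above.
pooled = hk.avg_pool(out,
                     window_shape=(1, 1, out.shape[2], out.shape[3]),
                     strides=(1, 1, out.shape[2], out.shape[3]),
                     padding='SAME')
pooled = pooled.reshape(pooled.shape[0], -1)    # [2, 4]

# Equivalent plain mean over the pooled axes.
pooled_mean = jnp.mean(out, axis=(2, 3))        # [2, 4]

print(bool(jnp.allclose(pooled, pooled_mean)))  # True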
Example 3
    o2 = token_embedding_map(x)

    # Transpose to time-major [T, B, E]; the unroll functions expect the
    # sequence axis first (a reshape here would mix batch and time).
    o2 = jnp.transpose(o2, (1, 0, 2))

    # LSTM Part of Network
    core = hk.LSTM(100)
    if args and args.dynamic_unroll:
        outs, state = hk.dynamic_unroll(core, o2,
                                        core.initial_state(x.shape[0]))
    else:
        outs, state = hk.static_unroll(core, o2,
                                       core.initial_state(x.shape[0]))
    outs = jnp.transpose(outs, (1, 0, 2))  # Back to batch-major [B, T, H].

    # Avg Pool -> Linear
    red_dim_outs = hk.avg_pool(outs, seq_len, seq_len, "SAME").squeeze()
    final_layer = hk.Linear(2)
    ret = final_layer(red_dim_outs)

    return ret


def embedding_model(arr, vocab_size=10_000, seq_len=256, **_):
    x = arr
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size + 4, 16, w_init=embed_init)

    # no religion on earth can justify the different behavior between
    # jax and pytorch
    # o2 = jnp.transpose(o1, (0, 2, 1))
    # here the dimensions would be [b, seq_dim, embed_out_dim]
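Putting Example 3's pieces together, the pattern is: embed the tokens, unroll an LSTM over the time-major sequence, average-pool the per-step outputs over the whole sequence, and classify. A minimal, self-contained sketch of that pattern; the function name, vocabulary size, hidden size, and shapes below are assumptions, not taken from the snippet.

import haiku as hk
import jax
import jax.numpy as jnp

def toy_lstm_classifier(tokens):
    # tokens: int32 [B, T]
    embed = hk.Embed(vocab_size=1000, embed_dim=16)
    o = embed(tokens)                       # [B, T, 16]
    o = jnp.transpose(o, (1, 0, 2))         # time-major [T, B, 16] for the unroll
    core = hk.LSTM(32)
    outs, _ = hk.static_unroll(core, o, core.initial_state(tokens.shape[0]))
    outs = jnp.transpose(outs, (1, 0, 2))   # back to [B, T, 32]
    # Average over the full sequence, then classify.
    pooled = hk.avg_pool(outs,
                         window_shape=tokens.shape[1],
                         strides=tokens.shape[1],
                         padding='SAME').squeeze(axis=1)  # [B, 32]
    return hk.Linear(2)(pooled)

forward = hk.without_apply_rng(hk.transform(toy_lstm_classifier))
tokens = jnp.zeros((4, 8), dtype=jnp.int32)
params = forward.init(jax.random.PRNGKey(0), tokens)
logits = forward.apply(params, tokens)
print(logits.shape)  # (4, 2)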