def __call__(self, x, is_training):
    out = self.activation(x) * self.beta
    if self.stride > 1:  # Average-pool downsample.
        # This is the transition-block path; when does self.stride > 1 actually apply?
        shortcut = hk.avg_pool(out, window_shape=(1, 2, 2, 1),
                               strides=(1, 2, 2, 1), padding='SAME')
        if self.use_projection:
            shortcut = self.conv_shortcut(shortcut)
    elif self.use_projection:
        shortcut = self.conv_shortcut(out)
    else:
        shortcut = x
    out = self.conv0(out)  # 1x1
    out = self.conv1(self.activation(out))  # 3x3
    if self.use_two_convs:
        out = self.conv1b(self.activation(out))  # 3x3
    out = self.conv2(self.activation(out))  # 1x1
    out = (self.se(out) * 2) * out  # Multiply by 2 for rescaling
    # Supposedly some paper reported that scaling the SE output by 2 works better(?)
    # Get average residual standard deviation for reporting metrics.
    res_avg_var = jnp.mean(jnp.var(out, axis=[0, 1, 2]))
    # Apply stochdepth if applicable.
    if self._has_stochdepth:
        out = self.stoch_depth(out, is_training)
    # SkipInit Gain
    out = out * hk.get_parameter('skip_gain', (), out.dtype, init=jnp.zeros)
    return out * self.alpha + shortcut, res_avg_var
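On the question in the comment: stride > 1 is typically set only for the first (transition) block of each stage, where the feature map is downsampled, which is why the shortcut has to be average-pooled and, when the channel count changes, projected. The skip_gain parameter at the end is the SkipInit trick: the residual branch is scaled by a learnable scalar initialised to zero, so every block starts out as its shortcut. A minimal self-contained sketch of that idea (the SkipInitBlock module and its hk.Linear residual branch are illustrative stand-ins, not the NFNet code):

import haiku as hk
import jax
import jax.numpy as jnp

class SkipInitBlock(hk.Module):
    def __call__(self, x):
        # Stand-in residual branch; the gain is a learnable scalar initialised
        # to zero, so the block is exactly the identity at initialisation.
        residual = hk.Linear(x.shape[-1])(jax.nn.relu(x))
        gain = hk.get_parameter('skip_gain', (), x.dtype, init=jnp.zeros)
        return x + gain * residual

f = hk.transform(lambda x: SkipInitBlock()(x))
params = f.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
out = f.apply(params, None, jnp.ones((2, 8)))  # equals the input, since skip_gain == 0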
def __call__(self, x, is_training):
    out = jax.nn.relu(self.bn1(self.conv1(x), is_training=is_training))
    out = self.layer1(out, is_training=is_training)
    out = self.layer2(out, is_training=is_training)
    out = self.layer3(out, is_training=is_training)
    out = hk.avg_pool(out,
                      window_shape=(1, 1, out.shape[2], out.shape[3]),
                      strides=(1, 1, out.shape[2], out.shape[3]),
                      padding='SAME')
    out = out.reshape(out.shape[0], -1)
    out = self.linear(out)
    return out
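Because this forward pass runs BatchNorm with an is_training flag, the network has to be wrapped with hk.transform_with_state so the moving statistics are carried between calls. A minimal self-contained sketch of that pattern, with a toy conv + BatchNorm stem standing in for the full model:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
    # Toy stand-in for the real network: conv stem + BatchNorm + ReLU.
    out = hk.Conv2D(16, kernel_shape=3, stride=1, padding='SAME')(x)
    out = hk.BatchNorm(create_scale=True, create_offset=True,
                       decay_rate=0.9)(out, is_training=is_training)
    return jax.nn.relu(out)

forward_t = hk.transform_with_state(forward)
rng = jax.random.PRNGKey(42)
dummy = jnp.zeros((8, 32, 32, 3))  # NHWC batch
params, state = forward_t.init(rng, dummy, is_training=True)
out, state = forward_t.apply(params, state, rng, dummy, is_training=True)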
o2 = token_embedding_map(x)
# Swap to time-major [seq_len, batch, embed]: transpose (not reshape) keeps
# each sequence intact for the unroll below.
o2 = jnp.transpose(o2, (1, 0, 2))

# LSTM Part of Network
core = hk.LSTM(100)
if args and args.dynamic_unroll:
    outs, state = hk.dynamic_unroll(core, o2, core.initial_state(x.shape[0]))
else:
    outs, state = hk.static_unroll(core, o2, core.initial_state(x.shape[0]))
# Back to batch-major [batch, seq_len, hidden].
outs = jnp.transpose(outs, (1, 0, 2))

# Avg Pool -> Linear
red_dim_outs = hk.avg_pool(outs, seq_len, seq_len, "SAME").squeeze()
final_layer = hk.Linear(2)
ret = final_layer(red_dim_outs)
return ret

def embedding_model(arr, vocab_size=10_000, seq_len=256, **_):
    x = arr
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size + 4, 16, w_init=embed_init)
    # no religion on earth can justify the different behavior between
    # jax and pytorch
    # o2 = jnp.transpose(o1, (0, 2, 1))
    # here the dimensions would be [b, seq_dim, embed_out_dim]
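As a side note on the unroll pattern above: hk.LSTM is driven over a time-major [seq_len, batch, features] sequence by hk.static_unroll (or hk.dynamic_unroll), with the initial state sized by the batch dimension. A minimal self-contained sketch of just that part (the shapes and hidden size here are illustrative assumptions, not the values from the model above):

import haiku as hk
import jax
import jax.numpy as jnp

def lstm_tail(seq):  # seq: time-major [T, B, F]
    core = hk.LSTM(100)
    outs, _ = hk.static_unroll(core, seq, core.initial_state(seq.shape[1]))
    return outs  # [T, B, 100]

f = hk.transform(lstm_tail)
dummy = jnp.zeros((256, 8, 16))  # seq_len=256, batch=8, embedding dim=16
params = f.init(jax.random.PRNGKey(0), dummy)
outs = f.apply(params, None, dummy)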