def __init__(self, h):
    """Build the module's convolutional layers from the config object `h`.

    Args:
      h: configuration object. This constructor reads
        `resblock`, `resblock_kernel_sizes`, `resblock_dilation_sizes`,
        `upsample_rates`, `upsample_kernel_sizes` and
        `upsample_initial_channel` from it.
    """
    super().__init__()
    self.h = h
    self.num_kernels = len(h.resblock_kernel_sizes)
    self.num_upsamples = len(h.upsample_rates)

    # NOTE: layers are created in a fixed order on purpose; Haiku derives
    # default module names from creation order.
    base_channels = h.upsample_initial_channel
    self.conv_pre = hk.Conv1D(base_channels, 7, 1, padding=((3, 3),))

    # '1' selects ResBlock1; any other value falls back to ResBlock2.
    resblock = ResBlock1 if h.resblock == '1' else ResBlock2

    # One transposed convolution per upsampling stage; the channel count is
    # halved at every stage.
    self.ups = [
        hk.Conv1DTranspose(
            base_channels // (2 ** (i + 1)),
            kernel_shape=kernel,
            stride=rate,
            padding='SAME',
            name=f"ups_{i}",
        )
        for i, (rate, kernel) in enumerate(
            zip(h.upsample_rates, h.upsample_kernel_sizes))
    ]

    # For every upsampling stage, one residual block per
    # (kernel size, dilation) pair, numbered consecutively across stages.
    self.resblocks = []
    for stage in range(len(self.ups)):
        stage_channels = base_channels // (2 ** (stage + 1))
        for kernel, dilations in zip(h.resblock_kernel_sizes,
                                     h.resblock_dilation_sizes):
            self.resblocks.append(
                resblock(h, stage_channels, kernel, dilations,
                         name=f'res_block1_{len(self.resblocks)}'))

    self.conv_post = hk.Conv1D(1, 7, 1, padding=((3, 3),))
def __init__(self,
             h,
             channels,
             kernel_size=3,
             dilation=(1, 3, 5),
             name="resblock1"):
    """Create two stacks of three 1-D convolutions each.

    Args:
      h: configuration object; stored on the instance, not read here.
      channels: output channels of every convolution.
      kernel_size: kernel width shared by all convolutions.
      dilation: indexable of dilation rates; entries 0, 1 and 2 are used
        for the first stack.
      name: module name.
    """
    super().__init__(name=name)
    self.h = h

    # First stack: dilated convolutions, one per dilation rate.
    self.convs1 = []
    for i in range(3):
        rate = dilation[i]
        self.convs1.append(
            hk.Conv1D(channels,
                      kernel_size,
                      1,
                      rate=rate,
                      padding=get_padding(kernel_size, rate),
                      name=f'convs1_{i}'))

    # Second stack: the same shape, but always with dilation rate 1.
    self.convs2 = []
    for i in range(3):
        self.convs2.append(
            hk.Conv1D(channels,
                      kernel_size,
                      1,
                      rate=1,
                      padding=get_padding(kernel_size, 1),
                      name=f"convs2_{i}"))
def __init__(self, vocab_size, lstm_dim, dropout_rate, is_training=True):
    """Create the embedding, conv/batch-norm and bidirectional LSTM layers.

    Args:
      vocab_size: number of entries in the embedding table.
      lstm_dim: embedding size, conv output channels and LSTM hidden size.
      dropout_rate: stored on the instance; not used in the constructor.
      is_training: stored on the instance; not used in the constructor.
    """
    super().__init__()
    self.is_training = is_training
    self.dropout_rate = dropout_rate

    # NOTE: module creation order is kept as-is; Haiku derives default
    # module names from it.
    self.embed = hk.Embed(vocab_size=vocab_size, embed_dim=lstm_dim)

    # Three width-3 SAME convolutions, each with its own batch norm.
    self.conv1 = hk.Conv1D(output_channels=lstm_dim, kernel_shape=3,
                           padding='SAME')
    self.conv2 = hk.Conv1D(output_channels=lstm_dim, kernel_shape=3,
                           padding='SAME')
    self.conv3 = hk.Conv1D(output_channels=lstm_dim, kernel_shape=3,
                           padding='SAME')
    self.bn1 = hk.BatchNorm(create_scale=True, create_offset=True,
                            decay_rate=0.9)
    self.bn2 = hk.BatchNorm(create_scale=True, create_offset=True,
                            decay_rate=0.9)
    self.bn3 = hk.BatchNorm(create_scale=True, create_offset=True,
                            decay_rate=0.9)

    # Forward LSTM plus a backward LSTM wrapped in ResetCore.
    self.lstm_fwd = hk.LSTM(hidden_size=lstm_dim)
    self.lstm_bwd = hk.ResetCore(hk.LSTM(hidden_size=lstm_dim))
def __call__(self, q: jnp.ndarray, k: jnp.ndarray) -> jnp.ndarray:
    """Computes content plus relative-position scores for `q` against `k`.

    Args:
      q: The query, combined with `k` via the 'bthd,bThd->bhtT' einsum.
      k: The key; its shape is unpacked as
        (batch_size, key_length, num_heads, head_dim).

    Returns:
      The sum of the content score and the shifted relative-position score,
      laid out as the einsum output 'bhtT'.
    """
    # The key, not the query, determines the sequence length.
    batch_size, key_length, num_heads, head_dim = k.shape
    score_spec = 'bthd,bThd->bhtT'

    # Content-based addressing with the global content bias.
    content_score = jnp.einsum(score_spec, q + self._r_w_bias, k)

    # Sinusoidal positional encodings, with dropout applied.
    pos_enc = self._sinusoidal_pos_emb(key_length, batch_size)
    pos_enc = hk.dropout(hk.next_rng_key(), self._dropout_rate, pos_enc)

    # Bias-free width-1 convolution projects the encodings, which are then
    # reshaped to the key's (batch, length, heads, head_dim) layout.
    projection = hk.Conv1D(
        output_channels=self._dim,
        kernel_shape=1,
        with_bias=False,
        w_init=init.RandomNormal(stddev=self._init_scale))
    rel_pos_emb = jnp.reshape(
        projection(pos_enc),
        [batch_size, key_length, num_heads, head_dim])

    # Content-dependent positional bias with the global positional bias.
    rel_pos_score = relative_shift(
        jnp.einsum(score_spec, q + self._r_r_bias, rel_pos_emb))

    assert content_score.shape == rel_pos_score.shape
    return content_score + rel_pos_score
def __init__(self, is_training=True):
    """Create encoder, decoder, projection, prenet and postnet layers.

    Layer sizes are read from `FLAGS` (`vocab_size`, `acoustic_encoder_dim`,
    `acoustic_decoder_dim`, `mel_dim`, `postnet_dim`).

    Args:
      is_training: stored on the instance and forwarded to the encoder.
    """
    super().__init__()
    self.is_training = is_training

    # NOTE: module creation order is kept as-is; Haiku derives default
    # module names from it.
    self.encoder = TokenEncoder(FLAGS.vocab_size,
                                FLAGS.acoustic_encoder_dim,
                                0.5,
                                is_training)

    # Two stacked LSTMs with skip connections.
    decoder_layers = [
        hk.LSTM(FLAGS.acoustic_decoder_dim),
        hk.LSTM(FLAGS.acoustic_decoder_dim),
    ]
    self.decoder = hk.deep_rnn_with_skip_connections(decoder_layers)
    self.projection = hk.Linear(FLAGS.mel_dim)

    # prenet
    self.prenet_fc1 = hk.Linear(256, with_bias=True)
    self.prenet_fc2 = hk.Linear(256, with_bias=True)

    # postnet: four hidden width-5 convolutions followed by one that maps
    # to `mel_dim` channels.
    hidden_convs = [hk.Conv1D(FLAGS.postnet_dim, 5) for _ in range(4)]
    output_conv = hk.Conv1D(FLAGS.mel_dim, 5)
    self.postnet_convs = hidden_convs + [output_conv]

    # One batch norm per hidden conv; the trailing None gives this list the
    # same length as `postnet_convs` (the last conv has no batch norm).
    hidden_bns = [
        hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
        for _ in range(4)
    ]
    self.postnet_bns = hidden_bns + [None]
def __init__(self,
             word_embedding_matrix,
             sentence_dim=1024,
             name="text_module"):
    """Set up the word embedding lookup and the sentence convolution.

    Args:
      word_embedding_matrix: 2d matrix [vocab_size, embed_size] used as the
        embedding table.
      sentence_dim: output channels of the width-1 convolution.
      name: module name.
    """
    super().__init__(name=name)
    self._word_embedding_module = hk.Embed(
        embedding_matrix=word_embedding_matrix)
    self._conv1d_module = hk.Conv1D(
        output_channels=sentence_dim,
        kernel_shape=1,
        name="text_conv1")
def __init__(self,
             h,
             channels,
             kernel_size=3,
             dilation=(1, 3),
             name="ResBlock2"):
    """Create two dilated 1-D convolutions.

    Args:
      h: configuration object; stored on the instance, not read here.
      channels: output channels of both convolutions.
      kernel_size: kernel width shared by both convolutions.
      dilation: indexable of dilation rates; entries 0 and 1 are used.
      name: module name.
    """
    super().__init__(name=name)
    self.h = h

    # One convolution per dilation rate, padded via `get_padding`.
    self.convs = []
    for idx in range(2):
        rate = dilation[idx]
        self.convs.append(
            hk.Conv1D(channels,
                      kernel_size,
                      1,
                      rate=rate,
                      padding=get_padding(kernel_size, rate)))
def forward(batch, is_training):
    """Embed -> Conv1D -> activation -> optional max-pool -> LSTM -> logits.

    Hyper-parameters (`max_features`, `embedding_size`, `num_filters`,
    `kernel_size`, `use_swish`, `use_maxpool`, `pool_size`, `cell_size`,
    `num_classes`) are taken from the enclosing scope.

    Args:
      batch: a pair whose first element is the input ids; the second
        element is ignored. The leading axis of the ids is taken as the
        batch size.
      is_training: accepted but not used by this function.

    Returns:
      The output of the final linear layer applied to the LSTM output at
      the last time step.
    """
    tokens, _ = batch
    batch_size = tokens.shape[0]

    x = hk.Embed(vocab_size=max_features, embed_dim=embedding_size)(tokens)
    x = hk.Conv1D(output_channels=num_filters,
                  kernel_shape=kernel_size,
                  padding="VALID")(x)

    activation = jax.nn.swish if use_swish else jax.nn.relu
    x = activation(x)

    if use_maxpool:
        x = hk.MaxPool(window_shape=pool_size,
                       strides=pool_size,
                       padding='VALID',
                       channel_axis=2)(x)

    # Swap the first two axes so the sequence axis leads, as the unroll
    # below iterates over axis 0.  [T, B, F]
    x = jnp.moveaxis(x, 1, 0)

    lstm_layer = hk.LSTM(hidden_size=cell_size)
    init_state = lstm_layer.initial_state(batch_size)
    outputs, _ = hk.static_unroll(lstm_layer, x, init_state)

    # Only the last time step feeds the classifier.
    logits = hk.Linear(num_classes)(outputs[-1])
    return logits
name="nets.VectorQuantizerEMA", create=lambda: Training(hk.nets.VectorQuantizerEMA(64, 512, 0.25, 0.9)), shape=(BATCH_SIZE, 64)), # TODO(tomhennigan) Make these modules support unbatched input. ModuleDescriptor( name="ConvND", create=lambda: hk.ConvND(1, 3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="ConvNDTranspose", create=lambda: hk.ConvNDTranspose(1, 3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv1D", create=lambda: hk.Conv1D(3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv1DTranspose", create=lambda: hk.Conv1DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv2D", create=lambda: hk.Conv2D(3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv2DTranspose", create=lambda: hk.Conv2DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv3D",
def conv1d_model(inp):
    """Apply a one-channel, width-2 VALID Conv1D (stride 1, with bias)."""
    layer = hk.Conv1D(
        output_channels=1,
        kernel_shape=2,
        stride=1,
        padding='VALID',
        with_bias=True,
    )
    return layer(inp)
def conv1d(x, num_units, init_scale=0.02, with_bias=True):
    """Apply a width-1 Conv1D with random-normal initialised weights.

    Args:
      x: input passed to the convolution.
      num_units: number of output channels.
      init_scale: standard deviation of the weight initialiser.
      with_bias: whether the convolution adds a bias.

    Returns:
      The convolution output.
    """
    weight_init = init.RandomNormal(stddev=init_scale)
    pointwise = hk.Conv1D(
        output_channels=num_units,
        kernel_shape=1,
        with_bias=with_bias,
        w_init=weight_init,
    )
    return pointwise(x)