def __init__(self, h):
    """Construct the convolutional layers from the hyper-parameter object ``h``.

    Attributes read from ``h``: ``resblock_kernel_sizes``, ``upsample_rates``,
    ``upsample_initial_channel``, ``resblock`` (``'1'`` selects ``ResBlock1``;
    any other value selects ``ResBlock2``), ``upsample_kernel_sizes`` and
    ``resblock_dilation_sizes``.

    Sub-modules are created in this order. The order matters because
    ``conv_pre`` and ``conv_post`` receive automatically generated Haiku names:

    * ``conv_pre``: kernel-7, stride-1 ``hk.Conv1D`` producing
      ``upsample_initial_channel`` channels.
    * ``ups``: one ``hk.Conv1DTranspose`` per (rate, kernel size) pair.
      Stage ``i`` outputs ``upsample_initial_channel // 2**(i + 1)``
      channels, uses stride ``rate`` and ``'SAME'`` padding, and is named
      ``ups_{i}``.
    * ``resblocks``: for every upsampling stage, one residual block per
      (kernel size, dilation) pair. They are stored in one flat list, with all
      blocks of stage 0 first, then stage 1, and so on.
    * ``conv_post``: kernel-7, stride-1 ``hk.Conv1D`` producing one channel.

    Args:
      h: hyper-parameter container accessed by attribute. Presumably the model
        config object; TODO confirm its type against callers.
    """
    super().__init__()
    self.h = h
    self.num_kernels = len(h.resblock_kernel_sizes)
    # NOTE(review): len(self.ups) below is the shorter of upsample_rates and
    # upsample_kernel_sizes (zip truncates). It only equals num_upsamples when
    # the two lists have the same length; confirm configs guarantee this.
    self.num_upsamples = len(h.upsample_rates)

    # With kernel 7, 3 frames of padding per side and stride 1, the sequence
    # length is unchanged: L + 6 - 7 + 1 == L.
    self.conv_pre = hk.Conv1D(h.upsample_initial_channel, 7, 1,
                              padding=((3, 3),))

    resblock = ResBlock1 if h.resblock == '1' else ResBlock2

    # Upsampling stages. Each stage's channel count is computed once here and
    # reused for that stage's residual blocks below, so the two loops cannot
    # drift apart.
    self.ups = []
    stage_channels = []
    for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
        ch = h.upsample_initial_channel // (2**(i + 1))
        stage_channels.append(ch)
        self.ups.append(
            hk.Conv1DTranspose(ch,
                               kernel_shape=k,
                               stride=u,
                               padding='SAME',
                               name=f"ups_{i}"))

    # Residual blocks, flattened stage by stage. Both block types use the
    # 'res_block1_' prefix. The module name fixes the Haiku parameter path, so
    # it is kept unchanged to stay compatible with existing parameters.
    self.resblocks = []
    for ch in stage_channels:
        for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
            self.resblocks.append(
                resblock(h, ch, k, d, name=f'res_block1_{len(self.resblocks)}'))

    # Final projection to a single channel (presumably the output signal;
    # confirm against __call__). Created last so its auto-generated name
    # matches the original.
    self.conv_post = hk.Conv1D(1, 7, 1, padding=((3, 3),))
# TODO(tomhennigan) Make these modules support unbatched input. ModuleDescriptor( name="ConvND", create=lambda: hk.ConvND(1, 3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="ConvNDTranspose", create=lambda: hk.ConvNDTranspose(1, 3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv1D", create=lambda: hk.Conv1D(3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv1DTranspose", create=lambda: hk.Conv1DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv2D", create=lambda: hk.Conv2D(3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv2DTranspose", create=lambda: hk.Conv2DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv3D", create=lambda: hk.Conv3D(3, 3), shape=(BATCH_SIZE, 2, 2, 2, 2)), ModuleDescriptor( name="Conv3DTranspose",