示例#1
0
    def __init__(self, wrnn_dims, fc_dims, cond_channels, global_cond_channels):
        """Create the conv / RNN / WaveRNN sub-modules and the delay table.

        Arguments:
            wrnn_dims -- forwarded as the first argument of WaveRNN
            fc_dims -- forwarded as the second argument of WaveRNN
            cond_channels -- extra input channels concatenated to the conv2
                output before rnn0
            global_cond_channels -- passed to every Conv4 / RNN4 and added to
                the WaveRNN conditioning width

        Raises:
            RuntimeError -- if the delay between conv2 and the WaveRNN stage
                is not a multiple of 64.
        """
        super().__init__()
        n_conv, n_rnn = 128, 512
        warmup = 64
        self.warmup_steps = warmup

        # Sub-modules are created in a fixed order so parameter registration
        # (and therefore initialisation order) stays stable.
        self.conv0 = Conv4(1, n_conv, global_cond_channels)
        self.conv1 = Conv4(n_conv, n_conv, global_cond_channels)
        self.conv2 = Conv4(n_conv, n_conv, global_cond_channels)
        self.rnn0 = RNN4(n_conv + cond_channels, n_rnn, warmup, global_cond_channels)
        self.rnn1 = RNN4(n_conv + n_rnn, n_rnn, warmup, global_cond_channels)
        self.rnn2 = RNN4(n_conv + n_rnn, n_rnn, warmup, global_cond_channels)
        self.wavernn = WaveRNN(wrnn_dims, fc_dims, n_rnn + global_cond_channels, 0)

        # Each stage's delay is the previous stage's delay plus its own
        # increment; store the running total as self.delay_<stage>.
        stage_increments = (
            ('c0', 9),
            ('c1', 9 * 4),
            ('c2', 9 * 16),
            ('r0', warmup * 64),
            ('r1', warmup * 16),
            ('r2', warmup * 4),
            ('wr', warmup),
        )
        running_delay = 0
        for stage, increment in stage_increments:
            running_delay += increment
            setattr(self, 'delay_' + stage, running_delay)

        # The conditioning padding is counted in units of 64 samples, so the
        # conv2 -> WaveRNN delay must divide evenly.
        cond_delay = self.delay_wr - self.delay_c2
        pad_frames, leftover = divmod(cond_delay, 64)
        if leftover:
            raise RuntimeError(f'Overtone: bad cond delay: {cond_delay}')
        self.cond_pad = pad_frames
示例#2
0
class Model(nn.Module):
    """WaveRNN wrapped with a ConvInUpsampleNetwork for its conditioning input.

    The conditioning tensor is transposed on dims 1 and 2, passed through the
    upsampling network and then handed to the WaveRNN.
    """

    def __init__(self,
                 quantization_channels=256,
                 gru_channels=896,
                 fc_channels=896,
                 lc_channels=80,
                 upsample_factor=(5, 5, 8),
                 use_gru_in_upsample=True):
        """Build the upsampling network and the WaveRNN.

        Args:
            quantization_channels: first positional argument of WaveRNN.
            gru_channels: second positional argument of WaveRNN.
            fc_channels: third positional argument of WaveRNN.
            lc_channels: channel count given to the upsampler
                (``cin_channels``) and to WaveRNN as its fourth argument.
            upsample_factor: passed to the upsampler as ``upsample_scales``.
            use_gru_in_upsample: passed to the upsampler as ``use_gru``.
        """
        super().__init__()

        self.upsample = ConvInUpsampleNetwork(upsample_scales=upsample_factor,
                                              upsample_activation="none",
                                              upsample_activation_params={},
                                              mode="nearest",
                                              cin_channels=lc_channels,
                                              use_gru=use_gru_in_upsample)

        self.wavernn = WaveRNN(quantization_channels, gru_channels,
                               fc_channels, lc_channels)

    def forward(self, inputs, conditions):
        """Run the training forward pass.

        The first step along dim 1 of the upsampled conditioning is dropped
        before it is given to the WaveRNN.
        """
        conditions = self.upsample(conditions.transpose(1, 2))
        return self.wavernn(inputs, conditions[:, 1:, :])

    def after_update(self):
        """Delegate the post-optimiser-step hook to the WaveRNN."""
        self.wavernn.after_update()

    def generate(self, conditions):
        """Generate output from ``conditions`` without tracking gradients.

        The module is switched to eval mode for the duration of the call and
        then put back into whichever mode it was in before, even when
        generation raises. Previously it was always left in training mode.

        Returns:
            Whatever ``self.wavernn.generate`` returns.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                conditions = self.upsample(conditions.transpose(1, 2))
                return self.wavernn.generate(conditions)
        finally:
            self.train(was_training)
示例#3
0
 def __init__(self, rnn_dims, fc_dims, pad, upsample_factors, feat_dims, DEVICE="cuda"):
     """Build the upsampling network and WaveRNN on the requested device.

     Args:
         rnn_dims: first positional argument of WaveRNN.
         fc_dims: second positional argument of WaveRNN.
         pad: accepted for interface compatibility; not read here.
         upsample_factors: forwarded to UpsampleNetwork.
         feat_dims: feature width given to both UpsampleNetwork and WaveRNN.
         DEVICE: forwarded to both sub-modules and stored on the instance.
     """
     super().__init__()
     device_kwargs = {"DEVICE": DEVICE}

     self.n_classes = 256
     self.upsample = UpsampleNetwork(feat_dims, upsample_factors, **device_kwargs)
     self.wavernn = WaveRNN(rnn_dims, fc_dims, feat_dims, 0, **device_kwargs)

     # The parameter report runs before self.DEVICE is assigned, exactly as
     # in the original statement order.
     self.num_params()
     self.DEVICE = DEVICE
示例#4
0
    def __init__(self,
                 quantization_channels=256,
                 gru_channels=896,
                 fc_channels=896,
                 lc_channels=80,
                 upsample_factor=(5, 5, 8),
                 use_gru_in_upsample=True):
        """Build the conditioning upsampler and the WaveRNN.

        Args:
            quantization_channels: first positional argument of WaveRNN.
            gru_channels: second positional argument of WaveRNN.
            fc_channels: third positional argument of WaveRNN.
            lc_channels: channel count given to the upsampler
                (``cin_channels``) and to WaveRNN as its fourth argument.
            upsample_factor: passed to the upsampler as ``upsample_scales``.
            use_gru_in_upsample: passed to the upsampler as ``use_gru``.
        """
        super().__init__()

        # Collect the upsampler settings in one place, then expand them.
        upsample_config = {
            "upsample_scales": upsample_factor,
            "upsample_activation": "none",
            "upsample_activation_params": {},
            "mode": "nearest",
            "cin_channels": lc_channels,
            "use_gru": use_gru_in_upsample,
        }
        self.upsample = ConvInUpsampleNetwork(**upsample_config)

        wavernn_args = (quantization_channels, gru_channels, fc_channels, lc_channels)
        self.wavernn = WaveRNN(*wavernn_args)
示例#5
0
class Model(nn.Module):
    """WaveRNN conditioned through a FrameRateNet followed by an UpsampleNet.

    The conditioning tensor is transposed on dims 1 and 2, passed through the
    frame-rate network, transposed back, upsampled, and then handed to the
    WaveRNN.
    """

    def __init__(self,
                 quantization_channels=256,
                 gru_channels=896,
                 fc_channels=896,
                 lc_channels=80,
                 lc_out_channles=80,
                 upsample_factor=(5, 5, 8),
                 use_lstm=True,
                 lstm_layer=2,
                 upsample_method='duplicate'):
        """Build the three sub-networks and print the trainable parameter count.

        Args:
            quantization_channels: first positional argument of WaveRNN.
            gru_channels: second positional argument of WaveRNN.
            fc_channels: third positional argument of WaveRNN.
            lc_channels: input width of FrameRateNet; also the fourth
                positional argument of WaveRNN.
            lc_out_channles: output width of FrameRateNet and both
                ``input_size`` and ``output_size`` of UpsampleNet. (The
                spelling is part of the public signature.)
            upsample_factor: forwarded to UpsampleNet.
            use_lstm: forwarded to UpsampleNet.
            lstm_layer: forwarded to UpsampleNet.
            upsample_method: forwarded to UpsampleNet.
        """
        super().__init__()
        self.frame_net = FrameRateNet(lc_channels, lc_out_channles)
        self.upsample = UpsampleNet(input_size=lc_out_channles,
                                    output_size=lc_out_channles,
                                    upsample_factor=upsample_factor,
                                    use_lstm=use_lstm,
                                    lstm_layer=lstm_layer,
                                    upsample_method=upsample_method)
        # NOTE(review): the WaveRNN is sized with lc_channels although it is
        # fed the UpsampleNet output (configured with lc_out_channles). The
        # defaults are equal; confirm which width is intended when they differ.
        self.wavernn = WaveRNN(quantization_channels, gru_channels,
                               fc_channels, lc_channels)
        self.num_params()

    def forward(self, inputs, conditions):
        """Run the training forward pass.

        The first step along dim 1 of the upsampled conditioning is dropped
        before it is given to the WaveRNN.
        """
        conditions = self.frame_net(conditions.transpose(1, 2))
        conditions = self.upsample(conditions.transpose(1, 2))
        return self.wavernn(inputs, conditions[:, 1:, :])

    def after_update(self):
        """Delegate the post-optimiser-step hook to the WaveRNN."""
        self.wavernn.after_update()

    def generate(self, conditions):
        """Generate output from ``conditions`` without tracking gradients.

        The module is switched to eval mode for the duration of the call and
        then put back into whichever mode it was in before, even when
        generation raises. Previously it was always left in training mode.

        Returns:
            Whatever ``self.wavernn.generate`` returns.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                conditions = self.frame_net(conditions.transpose(1, 2))
                conditions = self.upsample(conditions.transpose(1, 2))
                return self.wavernn.generate(conditions)
        finally:
            self.train(was_training)

    def num_params(self):
        """Print the number of trainable parameters, in millions."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print('Trainable Parameters: %.3f million' % (trainable / 1_000_000))
示例#6
0
 def __init__(self,
              quantization_channels=256,
              gru_channels=896,
              fc_channels=896,
              lc_channels=80,
              lc_out_channles=80,
              upsample_factor=(5, 5, 8),
              use_lstm=True,
              lstm_layer=2,
              upsample_method='duplicate'):
     """Build the frame-rate net, the upsampler and the WaveRNN.

     Args:
         quantization_channels: first positional argument of WaveRNN.
         gru_channels: second positional argument of WaveRNN.
         fc_channels: third positional argument of WaveRNN.
         lc_channels: input width of FrameRateNet; also the fourth
             positional argument of WaveRNN.
         lc_out_channles: output width of FrameRateNet and both
             ``input_size`` and ``output_size`` of UpsampleNet.
         upsample_factor: forwarded to UpsampleNet.
         use_lstm: forwarded to UpsampleNet.
         lstm_layer: forwarded to UpsampleNet.
         upsample_method: forwarded to UpsampleNet.
     """
     super().__init__()

     self.frame_net = FrameRateNet(lc_channels, lc_out_channles)

     # Gather the upsampler settings once, then expand them into the call.
     upsample_config = {
         'input_size': lc_out_channles,
         'output_size': lc_out_channles,
         'upsample_factor': upsample_factor,
         'use_lstm': use_lstm,
         'lstm_layer': lstm_layer,
         'upsample_method': upsample_method,
     }
     self.upsample = UpsampleNet(**upsample_config)

     wavernn_args = (quantization_channels, gru_channels, fc_channels, lc_channels)
     self.wavernn = WaveRNN(*wavernn_args)

     # Report the trainable parameter count once everything is registered.
     self.num_params()
示例#7
0
class Model(nn.Module):
    """WaveRNN conditioned on upsampled mel features, with train/generate loops.

    Besides the usual ``forward``, the class carries its own training loop
    (``do_train``) and a helper that renders test utterances to wav files
    (``do_generate``). Both rely on module-level names defined elsewhere in
    the file (``logger``, ``env``, ``hop_length``, ``win_length``,
    ``sample_rate``, ``librosa``, ``upgrade_state_dict``).
    """

    def __init__(self, rnn_dims, fc_dims, pad, upsample_factors, feat_dims):
        """Build the upsampler and WaveRNN, then log the parameter count.

        Args:
            rnn_dims: first positional argument of WaveRNN.
            fc_dims: second positional argument of WaveRNN.
            pad: accepted for interface compatibility; not read here.
            upsample_factors: forwarded to UpsampleNetwork.
            feat_dims: feature width given to UpsampleNetwork and WaveRNN.
        """
        super().__init__()
        self.n_classes = 256
        self.upsample = UpsampleNetwork(feat_dims, upsample_factors)
        self.wavernn = WaveRNN(rnn_dims, fc_dims, feat_dims, 0)
        self.num_params()

    def forward(self, x, mels):
        """Upsample ``mels`` and run the WaveRNN on ``x`` with that conditioning."""
        cond = self.upsample(mels)
        return self.wavernn(x, cond.transpose(1, 2), None, None, None)

    def after_update(self):
        """Delegate the post-optimiser-step hook to the WaveRNN."""
        self.wavernn.after_update()

    def preview_upsampling(self, mels):
        """Return the upsampled conditioning for ``mels`` (no WaveRNN involved)."""
        return self.upsample(mels)

    def forward_generate(self,
                         mels,
                         deterministic=False,
                         use_half=False,
                         verbose=False):
        """Generate audio for a batch of mel features without gradients.

        Args:
            mels: batch of mel features; cast to half precision when
                ``use_half`` is set.
            deterministic: accepted for interface compatibility; not used.
            use_half: forwarded to ``self.wavernn.generate``.
            verbose: forwarded to ``self.wavernn.generate``.

        Returns:
            Whatever ``self.wavernn.generate`` returns.

        The module runs in eval mode for the duration of the call and is then
        restored to the mode it was in before, even when generation raises.
        Previously it was always left in training mode.
        """
        if use_half:
            mels = mels.half()
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                cond = self.upsample(mels)
                output = self.wavernn.generate(cond.transpose(1, 2),
                                               None,
                                               None,
                                               None,
                                               use_half=use_half,
                                               verbose=verbose)
        finally:
            self.train(was_training)
        return output

    def num_params(self):
        """Log the number of trainable parameters, in millions."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.log('Trainable Parameters: %.3f million' % (trainable / 1_000_000))

    def load_state_dict(self, dict, *args, **kwargs):
        """Load a checkpoint after passing it through ``upgrade_state_dict``.

        Extra positional/keyword arguments (e.g. ``strict``) are forwarded to
        ``nn.Module.load_state_dict``; the original override dropped them.
        """
        return super().load_state_dict(upgrade_state_dict(dict), *args, **kwargs)

    def do_train(self,
                 paths,
                 dataset,
                 optimiser,
                 epochs,
                 batch_size,
                 step,
                 lr=1e-4,
                 valid_index=None,
                 use_half=False):
        """Train for ``epochs`` epochs, checkpointing after every epoch.

        Args:
            paths: object providing ``checkpoint_dir``, ``model_path()``,
                ``step_path()``, ``model_hist_path(step)`` and ``gen_path()``.
            dataset: dataset given to the DataLoader; its ``path`` attribute
                is passed on to ``do_generate``.
            optimiser: optimiser whose param-group learning rates are
                overwritten with ``lr``; wrapped in apex's FP16_Optimizer
                when ``use_half`` is set.
            epochs: number of passes over ``dataset``.
            batch_size: DataLoader batch size.
            step: global step count to continue from.
            lr: learning rate applied to every param group.
            valid_index: ids of the utterances rendered by ``do_generate``
                at history checkpoints; ``None`` means none.
            use_half: train with half-precision inputs via apex.
        """
        if valid_index is None:
            valid_index = []
        if use_half:
            import apex
            optimiser = apex.fp16_utils.FP16_Optimizer(optimiser,
                                                       dynamic_loss_scale=True)
        for p in optimiser.param_groups:
            p['lr'] = lr
        criterion = nn.NLLLoss().cuda()
        k = 0
        saved_k = 0
        print(win_length, hop_length, win_length / hop_length)

        for e in range(epochs):
            # NOTE: the second collate argument is fixed at 16; an earlier
            # variant derived it from int(win_length / hop_length).
            trn_loader = DataLoader(
                dataset,
                collate_fn=lambda batch: env.collate(0, 16, 0, batch),
                batch_size=batch_size,
                num_workers=2,
                shuffle=True,
                pin_memory=True)

            start = time.time()
            running_loss_c = 0.
            running_loss_f = 0.

            iters = len(trn_loader)

            for i, (mels, coarse, fine, coarse_f,
                    fine_f) in enumerate(trn_loader):
                mels, coarse, fine, coarse_f, fine_f = mels.cuda(
                ), coarse.cuda(), fine.cuda(), coarse_f.cuda(), fine_f.cuda()
                # Trim the sample sequences on both ends along dim 1.
                coarse, fine, coarse_f, fine_f = [
                    t[:, hop_length:1 - hop_length]
                    for t in [coarse, fine, coarse_f, fine_f]
                ]
                if use_half:
                    mels = mels.half()
                    coarse_f = coarse_f.half()
                    fine_f = fine_f.half()

                # Per-step input: previous coarse, previous fine, current coarse.
                x = torch.cat([
                    coarse_f[:, :-1].unsqueeze(-1),
                    fine_f[:, :-1].unsqueeze(-1),
                    coarse_f[:, 1:].unsqueeze(-1),
                ], dim=2)

                p_c, p_f, _h_n = self(x, mels)
                loss_c = criterion(p_c.transpose(1, 2).float(), coarse[:, 1:])
                loss_f = criterion(p_f.transpose(1, 2).float(), fine[:, 1:])
                loss = loss_c + loss_f

                optimiser.zero_grad()
                if use_half:
                    # The FP16 wrapper handles loss scaling in its own backward.
                    optimiser.backward(loss)
                else:
                    loss.backward()
                optimiser.step()
                running_loss_c += loss_c.item()
                running_loss_f += loss_f.item()

                self.after_update()

                speed = (i + 1) / (time.time() - start)
                avg_loss_c = running_loss_c / (i + 1)
                avg_loss_f = running_loss_f / (i + 1)

                step += 1
                k = step // 1000
                logger.status(
                    f'Epoch: {e+1}/{epochs} -- Batch: {i+1}/{iters} -- Loss: c={avg_loss_c:#.4} f={avg_loss_f:#.4} -- Speed: {speed:#.4} steps/sec -- Step: {k}k '
                )

            # End of epoch: overwrite the rolling checkpoint and step counter.
            os.makedirs(paths.checkpoint_dir, exist_ok=True)
            torch.save(self.state_dict(), paths.model_path())
            np.save(paths.step_path(), step)
            logger.log_current_status()
            logger.log(
                f' <saved>; w[0][0] = {self.wavernn.gru.weight_ih_l0[0][0]}')
            # Keep a history checkpoint (and render samples) once more than
            # 50k steps have passed since the last one.
            if k > saved_k + 50:
                torch.save(self.state_dict(), paths.model_hist_path(step))
                saved_k = k
                self.do_generate(paths,
                                 step,
                                 dataset.path,
                                 valid_index,
                                 use_half=use_half)

    def do_generate(self,
                    paths,
                    step,
                    data_path,
                    test_index,
                    deterministic=False,
                    use_half=False,
                    verbose=False):
        """Render the utterances in ``test_index`` and write target/generated wavs.

        Args:
            paths: object providing ``gen_path()``.
            step: global step, used (in thousands) in the output file names.
            data_path: directory containing ``mel/<id>.npy`` and
                ``quant/<id>.npy``.
            test_index: ids of the utterances to render. When empty the
                method returns immediately (previously it crashed on
                ``max()`` of an empty list).
            deterministic: forwarded to ``forward_generate``.
            use_half: forwarded to ``forward_generate``.
            verbose: forwarded to ``forward_generate``.
        """
        if len(test_index) == 0:
            return

        k = step // 1000
        test_mels = [
            np.load(f'{data_path}/mel/{item_id}.npy') for item_id in test_index
        ]
        maxlen = max(x.shape[1] for x in test_mels)
        # Zero-pad every mel along dim 1 to maxlen + 1 columns. The pad uses
        # the mel's own row count instead of a hard-coded 80.
        aligned = [
            torch.cat([
                torch.FloatTensor(x),
                torch.zeros(x.shape[0], maxlen - x.shape[1] + 1),
            ], dim=1) for x in test_mels
        ]
        batch = torch.stack(aligned)
        print(batch.size())
        out = self.forward_generate(batch.cuda(),
                                    deterministic,
                                    use_half=use_half,
                                    verbose=verbose)

        os.makedirs(paths.gen_path(), exist_ok=True)
        # NOTE(review): librosa.output.write_wav was removed in librosa 0.8;
        # this requires an older librosa or a switch to another writer.
        for i, item_id in enumerate(test_index):
            gt = np.load(f'{data_path}/quant/{item_id}.npy')
            gt = (gt.astype(np.float32) + 0.5) / (2**15 - 0.5)
            librosa.output.write_wav(
                f'{paths.gen_path()}/{k}k_steps_{i}_target.wav',
                gt,
                sr=sample_rate)
            # Trim the generated audio to the target's length.
            audio = out[i][:len(gt)].cpu().numpy()
            librosa.output.write_wav(
                f'{paths.gen_path()}/{k}k_steps_{i}_generated.wav',
                audio,
                sr=sample_rate)
示例#8
0
class Overtone(nn.Module):
    """Stack of three Conv4 stages, three RNN4 stages and a WaveRNN.

    In ``forward`` the three conv stages are chained on the first channel of
    the input, the three RNN stages consume the conv outputs in reverse order
    (conv2 -> rnn0, conv1 -> rnn1, conv0 -> rnn2), and the last RNN output
    conditions the WaveRNN. ``generate`` runs the same stack one sample at a
    time, refreshing the stages every 4, 16 and 64 samples.
    """
    def __init__(self, wrnn_dims, fc_dims, cond_channels, global_cond_channels):
        """Create the sub-modules and the per-stage delay table.

        Arguments:
            wrnn_dims -- forwarded as the first argument of WaveRNN
            fc_dims -- forwarded as the second argument of WaveRNN
            cond_channels -- extra channels concatenated to the conv2 output
                before rnn0
            global_cond_channels -- passed to every Conv4 / RNN4 and added to
                the WaveRNN conditioning width

        Raises:
            RuntimeError -- if the delay between conv2 and the WaveRNN stage
                is not a multiple of 64
        """
        super().__init__()
        conv_channels = 128
        rnn_channels = 512
        self.warmup_steps = 64
        self.conv0 = Conv4(1, conv_channels, global_cond_channels)
        self.conv1 = Conv4(conv_channels, conv_channels, global_cond_channels)
        self.conv2 = Conv4(conv_channels, conv_channels, global_cond_channels)
        self.rnn0 = RNN4(conv_channels + cond_channels, rnn_channels, self.warmup_steps, global_cond_channels)
        self.rnn1 = RNN4(conv_channels + rnn_channels, rnn_channels, self.warmup_steps, global_cond_channels)
        self.rnn2 = RNN4(conv_channels + rnn_channels, rnn_channels, self.warmup_steps, global_cond_channels)
        self.wavernn = WaveRNN(wrnn_dims, fc_dims, rnn_channels + global_cond_channels, 0)

        # Cumulative delay of each stage: every stage adds its own increment
        # to the previous stage's delay. With the constants above this gives
        # c0=9, c1=45, c2=189, r0=4285, r1=5309, r2=5565, wr=5629.
        self.delay_c0 = 9
        self.delay_c1 = self.delay_c0 + 9 * 4
        self.delay_c2 = self.delay_c1 + 9 * 16
        self.delay_r0 = self.delay_c2 + self.warmup_steps * 64
        self.delay_r1 = self.delay_r0 + self.warmup_steps * 16
        self.delay_r2 = self.delay_r1 + self.warmup_steps * 4
        self.delay_wr = self.delay_r2 + self.warmup_steps

        # cond_pad is the conv2 -> WaveRNN delay counted in units of 64
        # (5440 // 64 = 85 with the constants above).
        cond_delay = self.delay_wr - self.delay_c2
        if cond_delay % 64 != 0:
            raise RuntimeError(f'Overtone: bad cond delay: {cond_delay}')
        self.cond_pad = cond_delay // 64

    def forward(self, x, cond, global_cond):
        """Run the full stack on a sequence (training path).

        Arguments:
            x -- input whose first last-dim channel feeds the conv stack; the
                full tensor, sliced from delay_r2 along dim 1, feeds the
                WaveRNN
            cond -- conditioning concatenated to the conv2 output before
                rnn0; presumably may be None, since it goes through
                filter_none -- TODO confirm
            global_cond -- passed to every stage; when not None it is also
                expanded along dim 1 and appended to the WaveRNN conditioning

        Returns:
            (p_c, p_f) from the WaveRNN with the first warmup_steps steps
            along dim 1 removed.
        """
        # NOTE: n is not used below.
        n = x.size(0)
        x_coarse = x[:, :, :1]
        c0 = self.conv0(x_coarse, global_cond)
        c1 = self.conv1(c0, global_cond)
        c2 = self.conv2(c1, global_cond)
        # Each RNN stage takes the matching conv output (sliced by the delay
        # difference, converted to that stage's step size) plus the previous
        # RNN's output.
        r0 = self.rnn0(torch.cat(filter_none([c2, cond]), dim=2), global_cond)[0]
        r1 = self.rnn1(torch.cat([c1[:, (self.delay_r0 - self.delay_c1) // 16:], r0], dim=2), global_cond)[0]
        r2 = self.rnn2(torch.cat([c0[:, (self.delay_r1 - self.delay_c0) // 4:], r1], dim=2), global_cond)[0]
        if global_cond is not None:
            # Repeat the per-item vector along dim 1 to match r2's length.
            global_cond = global_cond.unsqueeze(1).expand(-1, r2.size(1), -1)
        cond_w = torch.cat(filter_none([r2, global_cond]), dim=2)
        p_c, p_f, _ = self.wavernn(x[:, self.delay_r2:], cond_w, None, None, None)
        return p_c[:, self.warmup_steps:], p_f[:, self.warmup_steps:]

    def generate(self, cond, global_cond, n=None, seq_len=None, verbose=False, use_half=False):
        """Sample seq_len output values one step at a time.

        Arguments:
            cond -- conditioning indexed along dim 1 once every 64 samples;
                may be None only if both n and seq_len are given, because
                their defaults are read from cond
            global_cond -- passed to every stage; may be None
            n -- batch size; defaults to cond.size(0)
            seq_len -- number of samples; defaults to
                (cond.size(1) - cond_pad) * 64
            verbose -- log the sampled categories for the first 100 steps of
                every 10000
            use_half -- allocate working tensors in half precision

        Returns:
            The per-step samples stacked along dim 1.

        All working tensors are created on CUDA.
        """
        start = time.time()
        if n is None:
            n = cond.size(0)
        if seq_len is None:
            seq_len = (cond.size(1) - self.cond_pad) * 64
        # std_tensor is an empty template: new_zeros()/to() below copy its
        # dtype and device.
        if use_half:
            std_tensor = torch.tensor([]).cuda().half()
        else:
            std_tensor = torch.tensor([]).cuda()

        # Warmup
        # Run every stage on all-zero input to obtain the initial conv
        # buffers (c0, c1) and hidden states (h0..h3).
        c0 = self.conv0(std_tensor.new_zeros(n, 10, 1), global_cond).repeat(1, 10, 1)
        c1 = self.conv1(c0, global_cond).repeat(1, 10, 1)
        c2 = self.conv2(c1, global_cond)

        if cond is None:
            pad_cond = None
        else:
            pad_cond = cond[:, :self.cond_pad]
        #logger.log(f'pad_cond: {pad_cond.size()}')
        # The literal 85 equals cond_pad for the constants set in __init__.
        r0, h0 = self.rnn0(torch.cat(filter_none([c2.repeat(1, 85, 1), pad_cond]), dim=2), global_cond)
        r1, h1 = self.rnn1(torch.cat([c1.repeat(1, 9, 1)[:, :84], r0], dim=2), global_cond)
        r2, h2 = self.rnn2(torch.cat([c0.repeat(1, 8, 1), r1], dim=2), global_cond)
        if global_cond is not None:
            global_cond_1 = global_cond.unsqueeze(1).expand(-1, r2.size(1), -1)
        else:
            global_cond_1 = None
        h3 = self.wavernn(std_tensor.new_zeros(n, 64, 3), torch.cat(filter_none([r2, global_cond_1]), dim=2))[2]

        # Create cells
        # Presumably single-step versions of the sequence modules -- TODO
        # confirm against to_cell().
        cell0 = self.rnn0.to_cell()
        cell1 = self.rnn1.to_cell()
        cell2 = self.rnn2.to_cell()
        wcell = self.wavernn.to_cell()

        # Main loop!
        # coarse holds the most recent coarse values fed to conv0; c_val and
        # f_val are the previous step's coarse/fine values.
        coarse = std_tensor.new_zeros(n, 10, 1)
        c_val = std_tensor.new_zeros(n)
        f_val = std_tensor.new_zeros(n)
        zero = std_tensor.new_zeros(n)
        output = []
        for t in range(seq_len):
            #logger.log(f't = {t}')
            # t0/t1/t2 are the position of t inside the current block of
            # 4 / 16 / 64 samples; ct1/ct2 pick the buffer slot to write.
            # NOTE: ct0 is not used below.
            t0 = t % 4
            ct0 = (-t) % 4

            # Every 4 samples: refresh conv0 and rnn2.
            if t0 == 0:
                t1 = (t // 4) % 4
                ct1 = ((-t) // 4) % 4

                #logger.log(f'written to c0[{-ct1-1}]')
                c0[:, -ct1-1].copy_(self.conv0(coarse, global_cond).squeeze(1))
                # Shift the coarse buffer left by 4 to make room for the
                # next four values.
                coarse[:, :-4].copy_(coarse[:, 4:])

                # Every 16 samples: refresh conv1 and rnn1.
                if t1 == 0:
                    t2 = (t // 16) % 4
                    ct2 = ((-t) // 16) % 4

                    #logger.log('read c0')
                    #logger.log(f'written to c1[{-ct2-1}]')
                    c1[:, -ct2-1].copy_(self.conv1(c0, global_cond).squeeze(1))
                    c0[:, :-4].copy_(c0[:, 4:])

                    # Every 64 samples: refresh conv2 and rnn0.
                    if t2 == 0:
                        #logger.log('read c1')
                        #logger.log('written to c2')
                        c2 = self.conv2(c1, global_cond).squeeze(1)
                        c1[:, :-4].copy_(c1[:, 4:])

                        #logger.log('read c2')
                        #logger.log('written to r0')
                        if cond is None:
                            inp0 = c2
                        else:
                            # Conditioning is read cond_pad entries ahead of
                            # the current 64-sample block.
                            inp0 = torch.cat([c2, cond[:, t // 64 + self.cond_pad]], dim=1)
                        r0, h0 = cell0(inp0, global_cond, h0)

                    #logger.log(f'read r0[{t2}]')
                    #logger.log(f'written to r1')
                    #logger.log(f'c1: {c1.size()} r0: {r0.size()}')
                    r1, h1 = cell1(torch.cat([c1[:, -ct2-1], r0[:, t2]], dim=1), global_cond, h1)

                #logger.log(f'read r1[{t1}]')
                #logger.log(f'written to r2')
                #logger.log(f'c0: {c0.size()} r1: {r1.size()}')
                r2, h2 = cell2(torch.cat([c0[:, -ct1-1], r1[:, t1]], dim=1), global_cond, h2)

            #logger.log(f'read r2[{t0}]')
            wcond = torch.cat(filter_none([r2[:, t0], global_cond]), dim=1)

            # Coarse part first: the third input slot is zero because the
            # current coarse value is not known yet.
            x = torch.stack([c_val, f_val, zero], dim=1)
            o_c = wcell.forward_c(x, wcond, None, None, h3)
            c_cat = utils.nn.sample_softmax(o_c).float()
            c_val_new = (c_cat / 127.5 - 1.0).to(std_tensor)

            # Fine part, now given the freshly sampled coarse value; only
            # this call advances the hidden state h3.
            x = torch.stack([c_val, f_val, c_val_new], dim=1)
            o_f, h3 = wcell.forward_f(x, wcond, None, None, h3)
            f_cat = utils.nn.sample_softmax(o_f).float()
            f_val = (f_cat / 127.5 - 1.0).to(std_tensor)
            c_val = c_val_new

            # Combine the coarse and fine categories into one value and
            # rescale it.
            sample = (c_cat * 256 + f_cat) / 32767.5 - 1.0
            coarse[:, 6+t0].copy_(c_val.unsqueeze(1))

            if verbose and t % 10000 < 100:
                logger.log(f'c={c_cat[0]} f={f_cat[0]} sample={sample[0]}')
            output.append(sample)
            if t % 100 == 0 :
                speed = int((t + 1) / (time.time() - start))
                logger.status(f'{t+1}/{seq_len} -- Speed: {speed} samples/sec')

        return torch.stack(output, dim=1)

    def after_update(self):
        """Delegate the post-optimiser-step hook to the WaveRNN."""
        self.wavernn.after_update()

    def pad(self):
        """Return the total delay of the stack (delay_wr)."""
        return self.delay_wr
示例#9
0
class Overtone(nn.Module):
    """Stack of three Conv4 stages, three RNN4 stages and a WaveRNN.

    In ``forward`` the three conv stages are chained on the first channel of
    the input, the three RNN stages consume the conv outputs in reverse order
    (conv2 -> rnn0, conv1 -> rnn1, conv0 -> rnn2), and the last RNN output
    conditions the WaveRNN. ``generate`` runs the same stack one sample at a
    time, refreshing the stages every 4, 16 and 64 samples.
    """
    def __init__(self, wrnn_dims, fc_dims, cond_channels, global_cond_channels):
        """Create the sub-modules and the per-stage delay table.

        Arguments:
            wrnn_dims -- forwarded as the first argument of WaveRNN
            fc_dims -- forwarded as the second argument of WaveRNN
            cond_channels -- extra channels concatenated to the conv2 output
                before rnn0
            global_cond_channels -- passed to every Conv4 / RNN4 and added to
                the WaveRNN conditioning width

        Raises:
            RuntimeError -- if the delay between conv2 and the WaveRNN stage
                is not a multiple of 64
        """
        super().__init__()
        conv_channels = 128
        rnn_channels = 512
        self.warmup_steps = 64
        self.conv0 = Conv4(1, conv_channels, global_cond_channels)
        self.conv1 = Conv4(conv_channels, conv_channels, global_cond_channels)
        self.conv2 = Conv4(conv_channels, conv_channels, global_cond_channels)
        self.rnn0 = RNN4(conv_channels + cond_channels, rnn_channels, self.warmup_steps, global_cond_channels)
        self.rnn1 = RNN4(conv_channels + rnn_channels, rnn_channels, self.warmup_steps, global_cond_channels)
        self.rnn2 = RNN4(conv_channels + rnn_channels, rnn_channels, self.warmup_steps, global_cond_channels)
        self.wavernn = WaveRNN(wrnn_dims, fc_dims, rnn_channels + global_cond_channels, 0)

        # Cumulative delay of each stage: every stage adds its own increment
        # to the previous stage's delay. With the constants above this gives
        # c0=9, c1=45, c2=189, r0=4285, r1=5309, r2=5565, wr=5629.
        self.delay_c0 = 9
        self.delay_c1 = self.delay_c0 + 9 * 4
        self.delay_c2 = self.delay_c1 + 9 * 16
        self.delay_r0 = self.delay_c2 + self.warmup_steps * 64
        self.delay_r1 = self.delay_r0 + self.warmup_steps * 16
        self.delay_r2 = self.delay_r1 + self.warmup_steps * 4
        self.delay_wr = self.delay_r2 + self.warmup_steps

        # cond_pad is the conv2 -> WaveRNN delay counted in units of 64
        # (5440 // 64 = 85 with the constants above).
        cond_delay = self.delay_wr - self.delay_c2
        if cond_delay % 64 != 0:
            raise RuntimeError(f'Overtone: bad cond delay: {cond_delay}')
        self.cond_pad = cond_delay // 64

    def forward(self, x, cond, global_cond):
        """
        Run the full stack on a sequence (training path).

        Arguments:
            x -- input whose first last-dim channel feeds the conv stack; the
                full tensor, sliced from delay_r2 along dim 1, feeds the
                WaveRNN
            cond -- conditioning concatenated to the conv2 output before
                rnn0; presumably may be None, since it goes through
                filter_none -- TODO confirm
            global_cond -- speaker one-hot embedding (per the original
                author's note); passed to every stage and, when not None,
                expanded along dim 1 and appended to the WaveRNN conditioning

        Returns:
            (p_c, p_f) from the WaveRNN with the first warmup_steps steps
            along dim 1 removed.
        """
        # NOTE: n is not used below.
        n = x.size(0)
        x_coarse = x[:, :, :1]
        c0 = self.conv0(x_coarse, global_cond)
        c1 = self.conv1(c0, global_cond)
        c2 = self.conv2(c1, global_cond)
        # Each RNN stage takes the matching conv output (sliced by the delay
        # difference, converted to that stage's step size) plus the previous
        # RNN's output.
        r0 = self.rnn0(torch.cat(filter_none([c2, cond]), dim=2), global_cond)[0]
        r1 = self.rnn1(torch.cat([c1[:, (self.delay_r0 - self.delay_c1) // 16:], r0], dim=2), global_cond)[0]
        r2 = self.rnn2(torch.cat([c0[:, (self.delay_r1 - self.delay_c0) // 4:], r1], dim=2), global_cond)[0]
        if global_cond is not None:
            # Repeat the per-item vector along dim 1 to match r2's length.
            global_cond = global_cond.unsqueeze(1).expand(-1, r2.size(1), -1)
        cond_w = torch.cat(filter_none([r2, global_cond]), dim=2)
        p_c, p_f, _ = self.wavernn(x[:, self.delay_r2:], cond_w, None, None, None)
        return p_c[:, self.warmup_steps:], p_f[:, self.warmup_steps:]

    def generate(self, cond, global_cond, n=None, seq_len=None, verbose=False, use_half=False):
        """
        Sample seq_len output values one step at a time.

        usecase #1: called from vqvae model during test generation

        Arguments:
            cond -- conditioning indexed along dim 1 once every 64 samples;
                may be None only if both n and seq_len are given, because
                their defaults are read from cond
            global_cond -- passed to every stage; may be None
            n -- batch size; defaults to cond.size(0)
            seq_len -- number of samples; defaults to
                (cond.size(1) - cond_pad) * 64
            verbose -- accepted but not used in this variant
            use_half -- allocate working tensors in half precision

        Returns:
            The per-step samples stacked along dim 1.

        All working tensors are created on CUDA.
        """
        # NOTE: start is not read again in this variant.
        start = time.time()
        if n is None:
            n = cond.size(0)
        if seq_len is None:
            seq_len = (cond.size(1) - self.cond_pad) * 64
        # std_tensor is an empty template: new_zeros()/to() below copy its
        # dtype and device.
        if use_half:
            std_tensor = torch.tensor([]).cuda().half()
        else:
            std_tensor = torch.tensor([]).cuda()

        # Warmup
        # Run every stage on all-zero input to obtain the initial conv
        # buffers (c0, c1) and hidden states (h0..h3).
        c0 = self.conv0(std_tensor.new_zeros(n, 10, 1), global_cond).repeat(1, 10, 1)
        c1 = self.conv1(c0, global_cond).repeat(1, 10, 1)
        c2 = self.conv2(c1, global_cond)

        if cond is None:
            pad_cond = None
        else:
            pad_cond = cond[:, :self.cond_pad]
        # The literal 85 equals cond_pad for the constants set in __init__.
        r0, h0 = self.rnn0(torch.cat(filter_none([c2.repeat(1, 85, 1), pad_cond]), dim=2), global_cond)
        r1, h1 = self.rnn1(torch.cat([c1.repeat(1, 9, 1)[:, :84], r0], dim=2), global_cond)
        r2, h2 = self.rnn2(torch.cat([c0.repeat(1, 8, 1), r1], dim=2), global_cond)
        if global_cond is not None:
            global_cond_1 = global_cond.unsqueeze(1).expand(-1, r2.size(1), -1)
        else:
            global_cond_1 = None
        h3 = self.wavernn(std_tensor.new_zeros(n, 64, 3), torch.cat(filter_none([r2, global_cond_1]), dim=2))[2]

        # Create cells
        # Presumably single-step versions of the sequence modules -- TODO
        # confirm against to_cell().
        cell0 = self.rnn0.to_cell()
        cell1 = self.rnn1.to_cell()
        cell2 = self.rnn2.to_cell()
        wcell = self.wavernn.to_cell()

        # Main loop!
        # coarse holds the most recent coarse values fed to conv0; c_val and
        # f_val are the previous step's coarse/fine values.
        coarse = std_tensor.new_zeros(n, 10, 1)
        c_val = std_tensor.new_zeros(n)
        f_val = std_tensor.new_zeros(n)
        zero = std_tensor.new_zeros(n)
        output = []
        for t in range(seq_len):
            # t0/t1/t2 are the position of t inside the current block of
            # 4 / 16 / 64 samples; ct1/ct2 pick the buffer slot to write.
            # NOTE: ct0 is not used below.
            t0 = t % 4
            ct0 = (-t) % 4

            # Every 4 samples: refresh conv0 and rnn2.
            if t0 == 0:
                t1 = (t // 4) % 4
                ct1 = ((-t) // 4) % 4

                # Conv stride4
                c0[:, -ct1-1].copy_(self.conv0(coarse, global_cond).squeeze(1))
                # Shift the coarse buffer left by 4 to make room for the
                # next four values.
                coarse[:, :-4].copy_(coarse[:, 4:])

                # Every 16 samples: refresh conv1 and rnn1.
                if t1 == 0:
                    t2 = (t // 16) % 4
                    ct2 = ((-t) // 16) % 4

                    # Conv stride4
                    c1[:, -ct2-1].copy_(self.conv1(c0, global_cond).squeeze(1))
                    c0[:, :-4].copy_(c0[:, 4:])

                    # Every 64 samples: refresh conv2 and rnn0.
                    if t2 == 0:
                        # Conv stride4
                        c2 = self.conv2(c1, global_cond).squeeze(1)
                        c1[:, :-4].copy_(c1[:, 4:])

                        if cond is None:
                            inp0 = c2
                        else:
                            # Conditioning entry for this 64-sample block,
                            # read cond_pad entries ahead of t // 64.
                            inp0 = torch.cat([c2, cond[:, t // 64 + self.cond_pad]], dim=1)
                        # RNN0 manually looped here (see autoregressive h0)
                        r0, h0 = cell0(inp0, global_cond, h0)

                    # RNN1 manually looped here (see autoregressive h1)
                    r1, h1 = cell1(torch.cat([c1[:, -ct2-1], r0[:, t2]], dim=1), global_cond, h1)

                # RNN2 manually looped here (see autoregressive h2)
                r2, h2 = cell2(torch.cat([c0[:, -ct1-1], r1[:, t1]], dim=1), global_cond, h2)

            # conditioning for WaveRNN
            wcond = torch.cat(filter_none([r2[:, t0], global_cond]), dim=1)

            # WaveRNN w/ dual-softmax
            # Coarse part first: the third input slot is zero because the
            # current coarse value is not known yet.
            x = torch.stack([c_val, f_val, zero], dim=1)
            o_c = wcell.forward_c(x, wcond, None, None, h3)
            c_cat = utils.nn.sample_softmax(o_c).float()
            c_val_new = (c_cat / 127.5 - 1.0).to(std_tensor)

            # Fine part, now given the freshly sampled coarse value; only
            # this call advances the hidden state h3.
            x = torch.stack([c_val, f_val, c_val_new], dim=1)
            o_f, h3 = wcell.forward_f(x, wcond, None, None, h3)
            f_cat = utils.nn.sample_softmax(o_f).float()
            f_val = (f_cat / 127.5 - 1.0).to(std_tensor)
            c_val = c_val_new

            # Combine the coarse and fine categories into one value and
            # rescale it.
            sample = (c_cat * 256 + f_cat) / 32767.5 - 1.0
            coarse[:, 6+t0].copy_(c_val.unsqueeze(1))

            output.append(sample)

        return torch.stack(output, dim=1)

    def after_update(self):
        """Delegate the post-optimiser-step hook to the WaveRNN."""
        self.wavernn.after_update()

    def pad(self):
        """Return the total delay of the stack (delay_wr)."""
        return self.delay_wr